Skip to content

T27256: 当前队列中位数

思路:

  • 题意即维护一个队列,有如下三种操作:(1) 在队尾添加;(2) 弹出队首;(3) 查询队列中的数集的中位数。
  • 现在把队列的壳去掉,问题变为对一个可重集合的操作:有如下三种操作:(1) 插入一个数;(2) 删除一个数;(3) 查询集合的中位数。
  • 思路 A:最好写的做法就是用 vector 维护一个有序数列,每次 lower_bound 到插入 / 删除位置并用 vector.insert / vector.erase 执行相应操作。Python 玩家的话可以使用 bisect。但这样做的时间复杂度为 O(n2),本不应该通过。虽然我赛时过了 /youl
  • 思路 B:考虑“中位数”的性质:注意到我们总是可以将原集合分成两个集合 P,Q,其中:(1) P 中的元素都不大于中位数,Q 中的元素都不小于中位数;(2) |P||Q|{0,1}。分别用大根堆和小根堆维护两个集合。
  • 插入时,不难得知插入到哪个集合中可以维持性质 (1),插入完成后若不再满足 (2),则将其中一方的堆顶弹出,并放到另一方的堆顶。
  • 删除时,打个标记,并维护 |P|,|Q| 的实际值即可。
  • 查询时,若 |P|=|Q| 则答案为两个堆顶的平均数,若 |P||Q|=1 则答案为 P 的堆顶。
  • 思路 C:注意到平衡树可以胜任以上三种操作(第三种是求第 k 小的特殊情况)。当然,这题 n 并不太大,写值域线段树也可以通过。

代码(思路 A):

cpp
#include <algorithm>
#include <queue>
#include <vector>
#include <cstdio>

using namespace std;

char op[7];
queue<int> q;
vector<int> v;

void insert(int x){
	v.insert(lower_bound(v.begin(), v.end(), x), x);
}

void erase(int x){
	v.erase(lower_bound(v.begin(), v.end(), x));
}

int main(){
	int n, len = 0;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++){
		scanf("%s", op);
		if (op[0] == 'a'){
			int x;
			scanf("%d", &x);
			len++;
			q.push(x);
			insert(x);
		} else if (op[0] == 'd'){
			len--;
			erase(q.front());
			q.pop();
		} else {
			if (len % 2 == 1){
				printf("%d\n", v[len / 2]);
			} else {
				int sum = v[len / 2] + v[len / 2 - 1];
				printf("%d", sum / 2);
				if (sum % 2 == 1) printf(".5");
				printf("\n");
			}
		}
	}
	return 0;
}

代码(思路 C,值域线段树):

cpp
#include <queue>
#include <cstdio>

using namespace std;

struct Node {
	int ls;
	int rs;
	int size;
};

int root = 0, id = 0;
char op[7];
Node tree[3100001];
queue<int> q;

void update(int x){
	tree[x].size = tree[tree[x].ls].size + tree[tree[x].rs].size;
}

void add(int &x, int l, int r, int pos, int val){
	if (x == 0) x = ++id;
	if (l == r){
		tree[x].size += val;
		return;
	}
	int mid = (l + r) >> 1;
	if (pos <= mid){
		add(tree[x].ls, l, mid, pos, val);
	} else {
		add(tree[x].rs, mid + 1, r, pos, val);
	}
	update(x);
}

int get_kth_number(int x, int l, int r, int k){
	if (l == r) return l;
	int ls = tree[x].ls;
	if (k <= tree[ls].size) return get_kth_number(ls, l, (l + r) >> 1, k);
	return get_kth_number(tree[x].rs, ((l + r) >> 1) + 1, r, k - tree[ls].size);
}

int main(){
	int n, len = 0;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++){
		scanf("%s", op);
		if (op[0] == 'a'){
			int x;
			scanf("%d", &x);
			len++;
			q.push(x);
			add(root, 0, 1e9, x, 1);
		} else if (op[0] == 'd'){
			len--;
			add(root, 0, 1e9, q.front(), -1);
			q.pop();
		} else {
			if (len % 2 == 1){
				printf("%d\n", get_kth_number(root, 0, 1e9, (len + 1) / 2));
			} else {
				int sum = get_kth_number(root, 0, 1e9, len / 2) + get_kth_number(root, 0, 1e9, len / 2 + 1);
				printf("%d", sum / 2);
				if (sum % 2 == 1) printf(".5");
				printf("\n");
			}
		}
	}
	return 0;
}