T27256: 当前队列中位数
思路:
- 题意即维护一个队列,有如下三种操作:(1) 在队尾添加;(2) 弹出队首;(3) 查询队列中的数集的中位数。
- 现在把队列的壳去掉,问题变为对一个可重集合的操作:有如下三种操作:(1) 插入一个数;(2) 删除一个数;(3) 查询集合的中位数。
- 思路 A:最好写的做法就是用 vector 维护一个有序数列,每次 lower_bound 到插入 / 删除位置并用 vector.insert / vector.erase 执行相应操作。Python 玩家的话可以使用 bisect。但这样做的时间复杂度为
,本不应该通过。 虽然我赛时过了 /youl - 思路 B:考虑“中位数”的性质:注意到我们总是可以将原集合分成两个集合
,其中:(1) 中的元素都不大于中位数, 中的元素都不小于中位数;(2) 。分别用大根堆和小根堆维护两个集合。 - 插入时,不难得知插入到哪个集合中可以维持性质 (1),插入完成后若不再满足 (2),则将其中一方的堆顶弹出,并放到另一方的堆顶。
- 删除时,打个标记,并维护
的实际值即可。 - 查询时,若
则答案为两个堆顶的平均数,若 则答案为 的堆顶。 - 思路 C:注意到平衡树可以胜任以上三种操作(第三种是求第 k 小的特殊情况)。当然,这题
并不太大,写值域线段树也可以通过。
代码(思路 A):
cpp
#include <algorithm>
#include <queue>
#include <vector>
#include <cstdio>
using namespace std;
char op[7];
queue<int> q;
vector<int> v;
void insert(int x){
v.insert(lower_bound(v.begin(), v.end(), x), x);
}
void erase(int x){
v.erase(lower_bound(v.begin(), v.end(), x));
}
int main(){
int n, len = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++){
scanf("%s", op);
if (op[0] == 'a'){
int x;
scanf("%d", &x);
len++;
q.push(x);
insert(x);
} else if (op[0] == 'd'){
len--;
erase(q.front());
q.pop();
} else {
if (len % 2 == 1){
printf("%d\n", v[len / 2]);
} else {
int sum = v[len / 2] + v[len / 2 - 1];
printf("%d", sum / 2);
if (sum % 2 == 1) printf(".5");
printf("\n");
}
}
}
return 0;
}代码(思路 C,值域线段树):
cpp
#include <queue>
#include <cstdio>
using namespace std;
struct Node {
int ls;
int rs;
int size;
};
int root = 0, id = 0;
char op[7];
Node tree[3100001];
queue<int> q;
void update(int x){
tree[x].size = tree[tree[x].ls].size + tree[tree[x].rs].size;
}
void add(int &x, int l, int r, int pos, int val){
if (x == 0) x = ++id;
if (l == r){
tree[x].size += val;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid){
add(tree[x].ls, l, mid, pos, val);
} else {
add(tree[x].rs, mid + 1, r, pos, val);
}
update(x);
}
int get_kth_number(int x, int l, int r, int k){
if (l == r) return l;
int ls = tree[x].ls;
if (k <= tree[ls].size) return get_kth_number(ls, l, (l + r) >> 1, k);
return get_kth_number(tree[x].rs, ((l + r) >> 1) + 1, r, k - tree[ls].size);
}
int main(){
int n, len = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++){
scanf("%s", op);
if (op[0] == 'a'){
int x;
scanf("%d", &x);
len++;
q.push(x);
add(root, 0, 1e9, x, 1);
} else if (op[0] == 'd'){
len--;
add(root, 0, 1e9, q.front(), -1);
q.pop();
} else {
if (len % 2 == 1){
printf("%d\n", get_kth_number(root, 0, 1e9, (len + 1) / 2));
} else {
int sum = get_kth_number(root, 0, 1e9, len / 2) + get_kth_number(root, 0, 1e9, len / 2 + 1);
printf("%d", sum / 2);
if (sum % 2 == 1) printf(".5");
printf("\n");
}
}
}
return 0;
}