20018: 蚂蚁王国的越野跑
merge sort, binary indexed tree, binary search, http://cs101.openjudge.cn/practice/20018
为了促进蚂蚁家族身体健康,提高蚁族健身意识,蚂蚁王国举行了越野跑。假设越野跑共有N个蚂蚁参加,在一条笔直的道路上进行。N个蚂蚁在起点处站成一列,相邻两个蚂蚁之间保持一定的间距。比赛开始后,N个蚂蚁同时沿着道路向相同的方向跑去。换句话说,这N个蚂蚁可以看作x轴上的N个点,在比赛开始后,它们同时向X轴正方向移动。假设越野跑的距离足够远,这N个蚂蚁的速度有的不相同有的相同且保持匀速运动,那么会有多少对参赛者之间发生“赶超”的事件呢?此题结果比较大,需要定义long long类型。请看备注。

输入
第一行1个整数N。 第2… N +1行:N 个非负整数,按从前到后的顺序给出每个蚂蚁的跑步速度。对于50%的数据,2<=N<=1000。对于100%的数据,2<=N<=100000。
输出
一个整数,表示有多少对参赛者之间发生赶超事件。
样例输入
sample1 input:
5
1
5
10
7
6
sample2 input:
5
1
5
5
7
6样例输出
sample1 output:
7
sample2 output:
8提示
我们把这5个蚂蚁依次编号为A,B,C,D,E,假设速度分别为1,5,5,7,6。在跑步过程中:B,C,D,E均会超过A,因为他们的速度都比A快;D,E都会超过B,C,因为他们的速度都比B,C快;D,E之间不会发生赶超,因为速度快的起跑时就在前边;B,C之间不会发生赶超,因为速度一样,在前面的就一直在前面。
考虑归并排序的思想。
此题结果比较大,需要定义long long类型,其输出格式为printf("%lld",x); long long,有符号 64位整数,所占8个字节(Byte) -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807
出题人太随意了。题目中的样例数据是 2 组样例数据。
这题本质上是 统计逆序对。
关键观察
蚂蚁一开始按位置从前到后排好: 若前面的蚂蚁 i 速度 v[i] 小于 后面的蚂蚁 j 速度 v[j],那么 j 一定会追上 i。
因此需要统计:i < j 且 v[i] < v[j] 的对数。
直觉解(bisect)
复杂度:O(N^2)**
pythonfrom bisect import bisect_left n=int(input()) v=[] ans=0 for i in range(n): p=int(input()) index=bisect_left(v,p) v.insert(index,p) ans+=index print(ans)bisect_left 的作用
index = bisect_left(v, p)含义:在有序数组
v中找到 p 应该插入的位置并且这个位置满足:
v[0:index] < p v[index:] >= p因此:
index = 已有元素中 < p 的个数这正好就是需要统计的数量。
分治解(归并统计)
而 归并排序统计逆序对通常统计的是:i < j 且 v[i] > v[j]
所以只要把速度 取负数,就能直接用逆序对模板。
时间复杂度:O(N log N),适合 N ≤ 100000。
Python实现(归并排序统计)
import sys
sys.setrecursionlimit(1000000)
def merge_sort(arr):
n = len(arr)
if n <= 1:
return arr, 0
mid = n // 2
left, cnt1 = merge_sort(arr[:mid])
right, cnt2 = merge_sort(arr[mid:])
i = j = 0
merged = []
cnt = cnt1 + cnt2
while i < len(left) and j < len(right):
if left[i] <= right[j]:
merged.append(left[i])
i += 1
else:
merged.append(right[j])
cnt += len(left) - i
j += 1
merged.extend(left[i:])
merged.extend(right[j:])
return merged, cnt
def main():
input = sys.stdin.readline
n = int(input())
v = [int(input()) for _ in range(n)]
# 转成负数,把 v[i] < v[j] 转成逆序对
arr = [-x for x in v]
_, ans = merge_sort(arr)
print(ans)
if __name__ == "__main__":
main()思维图(非常关键)
原序列:
位置: 1 2 3 4 5
速度: 1 5 10 7 6追赶发生条件:
前面慢,后面快
v[i] < v[j]例如:
1 < 5
1 < 10
1 < 7
1 < 6
...统计这种对数即可。
为什么题目提示归并排序
因为:
- N 最大 100000
- 暴力 O(N^2) 会 超时
- 归并统计逆序对是经典 O(N log N)
数据结构解(BIT)
下面给出 BIT(树状数组)版本。这个题其实非常适合讲 树状数组统计顺序对。
我们要统计:i < j 且 v[i] < v[j]
思路是 从左到右扫描:
- 已经扫描过的蚂蚁速度放进树状数组
- 当前蚂蚁速度为
v - 查询 之前有多少速度 < v
这样就得到所有 (i,j)。但速度范围可能很大,所以需要 离散化。
Python(BIT实现)
import sys
input = sys.stdin.readline
class BIT:
def __init__(self, n):
self.n = n
self.tree = [0]*(n+1)
def update(self, i, v=1):
while i <= self.n:
self.tree[i] += v
i += i & -i
def query(self, i):
s = 0
while i > 0:
s += self.tree[i]
i -= i & -i
return s
n = int(input())
v = [int(input()) for _ in range(n)]
# 1. 离散化
vals = sorted(set(v))
rank = {x:i+1 for i,x in enumerate(vals)}
# 2. BIT
bit = BIT(len(vals))
ans = 0
# 3. 从左到右扫描
for x in v:
r = rank[x]
# 统计之前速度 < x
ans += bit.query(r-1)
# 当前速度加入
bit.update(r)
print(ans)运行逻辑示例
输入:
1 5 10 7 6扫描过程:
| 当前蚂蚁 | 速度 | 之前更慢的数量 | 累计 |
|---|---|---|---|
| 1 | 1 | 0 | 0 |
| 2 | 5 | 1 | 1 |
| 3 | 10 | 2 | 3 |
| 4 | 7 | 2 | 5 |
| 5 | 6 | 2 | 7 |
结果:
7复杂度
- 离散化:O(N log N)
- BIT操作:O(N log N)
总体:O(N log N),适合 N=10^5。
课堂讲解建议
这题特别适合讲三种方法的对比:
1️⃣ 暴力 O(N^2)
2️⃣ 归并排序统计逆序对 O(N log N)
3️⃣ 树状数组统计顺序对 O(N log N)
import sys
def merge_sort(arr, temp, left, right):
if left >= right:
return 0
mid = (left + right) // 2
inv_count = merge_sort(arr, temp, left, mid) + merge_sort(arr, temp, mid + 1, right)
# 归并过程,同时计算逆序数
i, j, k = left, mid + 1, left
while i <= mid and j <= right:
if arr[i] >= arr[j]: # 注意这里是 >=,保证稳定性
temp[k] = arr[i]
i += 1
else:
temp[k] = arr[j]
inv_count += (mid - i + 1) # 统计逆序对
j += 1
k += 1
while i <= mid:
temp[k] = arr[i]
i += 1
k += 1
while j <= right:
temp[k] = arr[j]
j += 1
k += 1
# 拷贝回原数组
for i in range(left, right + 1):
arr[i] = temp[i]
return inv_count
if __name__ == "__main__":
n = int(sys.stdin.readline().strip())
arr = [int(sys.stdin.readline().strip()) for _ in range(n)]
temp = [0] * n
result = merge_sort(arr, temp, 0, n - 1)
print(result)主要优化点:
- 索引传递优化:避免创建子列表,改为在原数组上进行归并排序,提高空间效率。
- 减少
extend操作:直接在temp中合并排序,最后一次性拷贝回arr,减少内存拷贝开销。 - 提高稳定性:使用
arr[i] <= arr[j],确保排序稳定。 - 使用
sys.stdin.readline():加速大规模输入读取,提升整体运行效率。
时间复杂度: temp 作为辅助数组)
import sys
from collections import Counter
sys.setrecursionlimit(200000)
def merge_count(arr):
# 归并排序,同时计算逆序数
n = len(arr)
if n <= 1:
return arr, 0
mid = n // 2
left, inv_left = merge_count(arr[:mid])
right, inv_right = merge_count(arr[mid:])
merged = []
i = j = 0
inv = inv_left + inv_right
while i < len(left) and j < len(right):
# 如果左边元素<=右边元素,不构成逆序对(注意:相等情况不算)
if left[i] <= right[j]:
merged.append(left[i])
i += 1
else:
# left[i] > right[j]构成逆序对,左边剩余的元素都大于right[j]
merged.append(right[j])
inv += len(left) - i
j += 1
merged.extend(left[i:])
merged.extend(right[j:])
return merged, inv
def main():
input_data = sys.stdin.read().split()
if not input_data:
return
n = int(input_data[0])
speeds = list(map(int, input_data[1:]))
total_pairs = n * (n - 1) // 2
# 统计相等对数(任意两个相同速度的蚂蚁不会发生赶超)
cnt = Counter(speeds)
equal_pairs = sum(v * (v - 1) // 2 for v in cnt.values())
# 计算传统逆序数:统计满足a[i] > a[j]的(i,j)
_, inv = merge_count(speeds)
# 根据分析,赶超事件的对数为:
result = total_pairs - equal_pairs - inv
print(result)
if __name__ == '__main__':
main()#23n2300011505(12号娱乐选手)
def merge_sort(l):
if len(l) <= 1:
return l, 0
mid = len(l) // 2
left, left_count = merge_sort(l[:mid])
right, right_count = merge_sort(l[mid:])
l, merge_count = merge(left, right)
return l, left_count + right_count + merge_count
def merge(left, right):
merged = []
left_index, right_index = 0, 0
count = 0
while left_index < len(left) and right_index < len(right):
if left[left_index] >= right[right_index]:
merged.append(left[left_index])
left_index += 1
else:
merged.append(right[right_index])
right_index += 1
count += len(left) - left_index
merged += left[left_index:]+right[right_index:]
return merged, count
n = int(input())
l = []
for i in range(n):
l.append(int(input()))
l, ans = merge_sort(l)
print(ans)直觉解(bisect),复杂度:O(N^2)
from bisect import bisect_left
n=int(input())
v=[]
ans=0
for i in range(n):
p=int(input())
index=bisect_left(v,p)
v.insert(index,p)
ans+=index
print(ans)bisect_left 的作用
index = bisect_left(v, p)含义:在有序数组
v中找到 p 应该插入的位置并且这个位置满足:
v[0:index] < p v[index:] >= p因此:
index = 已有元素中 < p 的个数这正好就是需要统计的数量。
#23n2300011042(Apocalypse)
import bisect
while True:
try:
n = int(input())
ans = 0
l = []
for _ in range(n):
t = int(input())
dx = len(l) - (bisect.bisect_right(l, -t))
ans += dx
bisect.insort_right(l, -t)
print(ans)
input()
except EOFError:
break