Skip to content

20018: 蚂蚁王国的越野跑

merge sort, binary indexed tree, binary search, http://cs101.openjudge.cn/practice/20018

为了促进蚂蚁家族身体健康,提高蚁族健身意识,蚂蚁王国举行了越野跑。假设越野跑共有N个蚂蚁参加,在一条笔直的道路上进行。N个蚂蚁在起点处站成一列,相邻两个蚂蚁之间保持一定的间距。比赛开始后,N个蚂蚁同时沿着道路向相同的方向跑去。换句话说,这N个蚂蚁可以看作x轴上的N个点,在比赛开始后,它们同时向X轴正方向移动。假设越野跑的距离足够远,这N个蚂蚁的速度有的不相同有的相同且保持匀速运动,那么会有多少对参赛者之间发生“赶超”的事件呢?此题结果比较大,需要定义long long类型。请看备注。

img

输入

第一行1个整数N。 第2… N +1行:N 个非负整数,按从前到后的顺序给出每个蚂蚁的跑步速度。对于50%的数据,2<=N<=1000。对于100%的数据,2<=N<=100000。

输出

一个整数,表示有多少对参赛者之间发生赶超事件。

样例输入

sample1 input:
5
1
5
10
7
6

sample2 input:
5
1
5
5
7
6

样例输出

sample1 output:
7

sample2 output:
8

提示

我们把这5个蚂蚁依次编号为A,B,C,D,E,假设速度分别为1,5,5,7,6。在跑步过程中:B,C,D,E均会超过A,因为他们的速度都比A快;D,E都会超过B,C,因为他们的速度都比B,C快;D,E之间不会发生赶超,因为速度快的起跑时就在前边;B,C之间不会发生赶超,因为速度一样,在前面的就一直在前面。

考虑归并排序的思想。

此题结果比较大,需要定义long long类型,其输出格式为printf("%lld",x); long long,有符号 64位整数,所占8个字节(Byte) -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807

出题人太随意了。题目中的样例数据是 2 组样例数据。

这题本质上是 统计逆序对

关键观察

蚂蚁一开始按位置从前到后排好: 若前面的蚂蚁 i 速度 v[i] 小于 后面的蚂蚁 j 速度 v[j],那么 j 一定会追上 i

因此需要统计:i < jv[i] < v[j] 的对数。

直觉解(bisect)

复杂度:O(N^2)**

python
from bisect import bisect_left
n=int(input())
v=[]
ans=0
for i in range(n):
    p=int(input())
    index=bisect_left(v,p)
    v.insert(index,p)
    ans+=index
print(ans)

bisect_left 的作用

index = bisect_left(v, p)

含义:在有序数组 v 中找到 p 应该插入的位置

并且这个位置满足:

v[0:index] < p
v[index:] >= p

因此:

index = 已有元素中 < p 的个数

这正好就是需要统计的数量。

分治解(归并统计)

归并排序统计逆序对通常统计的是:i < jv[i] > v[j]

所以只要把速度 取负数,就能直接用逆序对模板。

时间复杂度:O(N log N),适合 N ≤ 100000。


Python实现(归并排序统计)

python
import sys
sys.setrecursionlimit(1000000)

def merge_sort(arr):
    n = len(arr)
    if n <= 1:
        return arr, 0

    mid = n // 2
    left, cnt1 = merge_sort(arr[:mid])
    right, cnt2 = merge_sort(arr[mid:])

    i = j = 0
    merged = []
    cnt = cnt1 + cnt2

    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            cnt += len(left) - i
            j += 1

    merged.extend(left[i:])
    merged.extend(right[j:])

    return merged, cnt


def main():
    input = sys.stdin.readline
    n = int(input())
    v = [int(input()) for _ in range(n)]

    # 转成负数,把 v[i] < v[j] 转成逆序对
    arr = [-x for x in v]

    _, ans = merge_sort(arr)
    print(ans)


if __name__ == "__main__":
    main()

思维图(非常关键)

原序列:

位置: 1 2 3 4 5
速度: 1 5 10 7 6

追赶发生条件:

前面慢,后面快
v[i] < v[j]

例如:

1 < 5
1 < 10
1 < 7
1 < 6
...

统计这种对数即可。


为什么题目提示归并排序

因为:

  • N 最大 100000
  • 暴力 O(N^2) 会 超时
  • 归并统计逆序对是经典 O(N log N)

数据结构解(BIT)

下面给出 BIT(树状数组)版本。这个题其实非常适合讲 树状数组统计顺序对

我们要统计:i < jv[i] < v[j]

思路是 从左到右扫描

  • 已经扫描过的蚂蚁速度放进树状数组
  • 当前蚂蚁速度为 v
  • 查询 之前有多少速度 < v

这样就得到所有 (i,j)。但速度范围可能很大,所以需要 离散化


Python(BIT实现)

python
import sys
input = sys.stdin.readline

class BIT:
    def __init__(self, n):
        self.n = n
        self.tree = [0]*(n+1)

    def update(self, i, v=1):
        while i <= self.n:
            self.tree[i] += v
            i += i & -i

    def query(self, i):
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s


n = int(input())
v = [int(input()) for _ in range(n)]

# 1. 离散化
vals = sorted(set(v))
rank = {x:i+1 for i,x in enumerate(vals)}

# 2. BIT
bit = BIT(len(vals))

ans = 0

# 3. 从左到右扫描
for x in v:
    r = rank[x]
    
    # 统计之前速度 < x
    ans += bit.query(r-1)
    
    # 当前速度加入
    bit.update(r)

print(ans)

运行逻辑示例

输入:

1 5 10 7 6

扫描过程:

当前蚂蚁速度之前更慢的数量累计
1100
2511
31023
4725
5627

结果:

7

复杂度

  • 离散化:O(N log N)
  • BIT操作:O(N log N)

总体:O(N log N),适合 N=10^5。


课堂讲解建议

这题特别适合讲三种方法的对比:

1️⃣ 暴力 O(N^2)

2️⃣ 归并排序统计逆序对 O(N log N)

3️⃣ 树状数组统计顺序对 O(N log N)

python
import sys


def merge_sort(arr, temp, left, right):
    if left >= right:
        return 0
    mid = (left + right) // 2
    inv_count = merge_sort(arr, temp, left, mid) + merge_sort(arr, temp, mid + 1, right)

    # 归并过程,同时计算逆序数
    i, j, k = left, mid + 1, left
    while i <= mid and j <= right:
        if arr[i] >= arr[j]:  # 注意这里是 >=,保证稳定性
            temp[k] = arr[i]
            i += 1
        else:
            temp[k] = arr[j]
            inv_count += (mid - i + 1)  # 统计逆序对
            j += 1
        k += 1

    while i <= mid:
        temp[k] = arr[i]
        i += 1
        k += 1
    while j <= right:
        temp[k] = arr[j]
        j += 1
        k += 1

    # 拷贝回原数组
    for i in range(left, right + 1):
        arr[i] = temp[i]

    return inv_count


if __name__ == "__main__":
    n = int(sys.stdin.readline().strip())
    arr = [int(sys.stdin.readline().strip()) for _ in range(n)]
    temp = [0] * n
    result = merge_sort(arr, temp, 0, n - 1)
    print(result)

主要优化点:

  1. 索引传递优化:避免创建子列表,改为在原数组上进行归并排序,提高空间效率。
  2. 减少 extend 操作:直接在 temp 中合并排序,最后一次性拷贝回 arr,减少内存拷贝开销。
  3. 提高稳定性:使用 arr[i] <= arr[j],确保排序稳定。
  4. 使用 sys.stdin.readline():加速大规模输入读取,提升整体运行效率。

时间复杂度: O(NlogN)空间复杂度: O(N) (使用 temp 作为辅助数组)

python
import sys
from collections import Counter

sys.setrecursionlimit(200000)

def merge_count(arr):
    # 归并排序,同时计算逆序数
    n = len(arr)
    if n <= 1:
        return arr, 0
    mid = n // 2
    left, inv_left = merge_count(arr[:mid])
    right, inv_right = merge_count(arr[mid:])
    merged = []
    i = j = 0
    inv = inv_left + inv_right
    while i < len(left) and j < len(right):
        # 如果左边元素<=右边元素,不构成逆序对(注意:相等情况不算)
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # left[i] > right[j]构成逆序对,左边剩余的元素都大于right[j]
            merged.append(right[j])
            inv += len(left) - i
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inv

def main():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    n = int(input_data[0])
    speeds = list(map(int, input_data[1:]))
    total_pairs = n * (n - 1) // 2

    # 统计相等对数(任意两个相同速度的蚂蚁不会发生赶超)
    cnt = Counter(speeds)
    equal_pairs = sum(v * (v - 1) // 2 for v in cnt.values())

    # 计算传统逆序数:统计满足a[i] > a[j]的(i,j)
    _, inv = merge_count(speeds)
    # 根据分析,赶超事件的对数为:
    result = total_pairs - equal_pairs - inv
    print(result)

if __name__ == '__main__':
    main()
python
#23n2300011505(12号娱乐选手)
def merge_sort(l):
    if len(l) <= 1:
        return l, 0
    mid = len(l) // 2
    left, left_count = merge_sort(l[:mid])
    right, right_count = merge_sort(l[mid:])
    l, merge_count = merge(left, right)
    return l, left_count + right_count + merge_count


def merge(left, right):
    merged = []
    left_index, right_index = 0, 0
    count = 0
    while left_index < len(left) and right_index < len(right):
        if left[left_index] >= right[right_index]:
            merged.append(left[left_index])
            left_index += 1
        else:
            merged.append(right[right_index])
            right_index += 1
            count += len(left) - left_index
    merged += left[left_index:]+right[right_index:]
    return merged, count


n = int(input())
l = []
for i in range(n):
    l.append(int(input()))
l, ans = merge_sort(l)
print(ans)

直觉解(bisect),复杂度:O(N^2)

python
from bisect import bisect_left
n=int(input())
v=[]
ans=0
for i in range(n):
    p=int(input())
    index=bisect_left(v,p)
    v.insert(index,p)
    ans+=index
print(ans)

bisect_left 的作用

index = bisect_left(v, p)

含义:在有序数组 v 中找到 p 应该插入的位置

并且这个位置满足:

v[0:index] < p
v[index:] >= p

因此:

index = 已有元素中 < p 的个数

这正好就是需要统计的数量。

python
#23n2300011042(Apocalypse)
import bisect

while True:
    try:
        n = int(input())
        ans = 0
        l = []
        for _ in range(n):
            t = int(input())
            dx = len(l) - (bisect.bisect_right(l, -t))
            ans += dx
            bisect.insort_right(l, -t)
        print(ans)
        input()
    except EOFError:
        break