Skip to content

M29740:神经网络

topological order, http://cs101.openjudge.cn/practice/29740/

随着深度学习技术的发展,神经网络已广泛应用于图像识别、自然语言处理、推荐系统等多个领域。近年来,图神经网络(Graph Neural Networks, GNN)作为一种能够处理非欧式结构化数据的深度学习方法,越来越受到研究者关注。小P同学最近在学习图神经网络的基础知识时,尝试设计了一个简化的“前馈图神经网络”模型,并希望借助编程来验证其推理过程是否合理。

在小P的模型中,一个图神经网络被抽象成一个有向图,图中的节点代表神经元(或计算单元),边则表示信息的传递通道。任意两个节点之间逻辑上至多视为一条有向边。下图是一个神经元节点的示意图:

img

神经元(编号为 i)

图中,X1 ~ X3 表示来自上游节点的信息输入,Y1 ~ Y2 是向下游节点的输出通道,Ci 表示神经元当前的激活值(activation),而 Ui 是该节点的偏置项(bias),可视为其内在的阈值参数。

神经元按一定的顺序排列,构成整个图神经网络。在小P的模型之中,神经元近似于分层摆放,可以近似视作输入层、若干隐藏层与输出层;信息只能沿有向边由前向后传播。下图是一个最简单的三层神经网络的例子。

img

现实中若某些突触损坏,会产生“闭环反馈”使整个网络失效。若这幅有向图 存在回路(哪怕仅为自环),称为神经坏死,网络将彻底无法工作。

小P定义的神经元更新规则如下(令 n 表示总神经元数量):

img

其中:

  • Wji 表示从节点 j 到节点 i 的边权重(可以为负值)
  • Cj 是上游神经元的当前激活值
  • Ui 是当前神经元的偏置项
  • 若 Ci > 0,则该神经元处于激活(active)状态,下一时刻将以强度Ci 向下一层传递信号;否则,它将保持静默(inactive)

当输入层(入度为0的节点)的神经元激活后,该网络模型将以类似于图神经网络的方式逐层传播信号,最终在输出层(出度为0的节点)产生结果。

输入

输入文件第一行是两个整数 n(1≤n≤100)和 p(0 ≤ p ≤ n(n-1))。

接下来 n 行,每行 2 个整数,第 i+1 行是神经元 i 最初状态和其阈值(Ui),非输入层的神经元开始时状态必然为 0。

再下面 p 行,每行有两个整数 i,j 及一个整数 Wij,表示连接神经元 i,j 的边权值为 Wij。(i,j) 二元组可重复出现,若重复则视作一条边,将权重累加即可。

输出

输出文件包含若干行,每行有 2 个整数,分别对应一个神经元的编号,及其最后的状态,2 个整数间以空格分隔。仅输出最后状态大于 0 的输出层神经元状态,并且按照编号由小到大顺序输出。

若图中存在任何有向环(神经坏死),或所有输出层神经元最后状态均 ≤0,则输出 NULL。

样例输入

样例1:
5 6
1 0
1 0
0 1
0 1
0 1
1 3 1
1 4 1
1 5 1
2 3 1
2 4 1
2 5 1

样例2:
2 2
1 0
0 0
1 2 1
2 1 1

样例输出

样例1:
3 1
4 1
5 1

样例2:
NULL

提示

拓扑排序,AI

输入层神经元状态不需减去其偏置项 输入层可以激活也可以不激活,C <= 0 是不激活输入层

C/C++,变量类型需要考虑long long以防止溢出

来源

TA-xjk

python
import sys

# 增加递归深度限制(虽然本解法使用迭代,但作为通用的各种Python算法习惯)
sys.setrecursionlimit(5000)


def solve():
    # 使用 sys.stdin.read 一次性读取所有输入,提高效率
    try:
        input_data = sys.stdin.read().split()
    except Exception:
        return

    if not input_data:
        return

    iterator = iter(input_data)

    try:
        n = int(next(iterator))
        p = int(next(iterator))
    except StopIteration:
        return

    # 初始化数组,使用下标 1 到 n
    # c: 神经元状态
    # u: 阈值
    c = [0] * (n + 1)
    u = [0] * (n + 1)

    # 读取神经元初始信息
    for i in range(1, n + 1):
        c[i] = int(next(iterator))
        u[i] = int(next(iterator))

    # 处理边信息
    # 使用字典来处理重边情况(题目提示:元组可重复出现,权重累加)
    edge_weights = {}

    # 记录入度和出度
    in_degree = [0] * (n + 1)
    out_degree = [0] * (n + 1)

    for _ in range(p):
        src = int(next(iterator))
        dst = int(next(iterator))
        w = int(next(iterator))

        # 累加权重
        if (src, dst) in edge_weights:
            edge_weights[(src, dst)] += w
        else:
            edge_weights[(src, dst)] = w

    # 构建邻接表
    adj = [[] for _ in range(n + 1)]
    for (src, dst), w in edge_weights.items():
        adj[src].append((dst, w))
        in_degree[dst] += 1
        out_degree[src] += 1

    # 拓扑排序队列初始化
    queue = []

    # 初始化状态修正
    for i in range(1, n + 1):
        if in_degree[i] == 0:
            # 输入层:入度为0。
            # 它们直接使用读入的 c[i] 作为初始激活值,不需要减 u[i]
            queue.append(i)
        else:
            # 非输入层:
            # 公式是 C_i = sum(W * C_prev) - U_i
            # 我们将 C_i 初始化为 -U_i,之后累加 sum(W * C_prev) 即可
            c[i] -= u[i]

    # 执行拓扑排序
    topo_order = []
    head = 0  # 模拟队列指针

    while head < len(queue):
        curr = queue[head]
        head += 1
        topo_order.append(curr)

        # 信号传播
        # 规则:只有当 Ci > 0 时,才向后传播信号
        if c[curr] > 0:
            for neighbor, weight in adj[curr]:
                c[neighbor] += weight * c[curr]

        # 拓扑逻辑(无论是否传播信号,都要处理边的移除)
        for neighbor, weight in adj[curr]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    # 环检测
    # 如果拓扑序列长度小于 N,说明有环(神经坏死)
    if len(topo_order) < n:
        print("NULL")
        return

    # 输出结果
    has_output = False
    # 按照编号从小到大遍历(自然序)
    for i in range(1, n + 1):
        # 输出层定义:出度为 0
        if out_degree[i] == 0:
            # 只有状态 > 0 才输出
            if c[i] > 0:
                print(f"{i} {c[i]}")
                has_output = True

    # 如果没有合法的输出,打印 NULL
    if not has_output:
        print("NULL")


if __name__ == "__main__":
    solve()