```python
# LayerNorm for node and edge features
self.bn_x = nn.LayerNorm(hidden_dim)
self.bn_e = nn.LayerNorm(hidden_dim)
```
```python
def forward(self, x: Tensor, e: Tensor, edge_index: Tensor):
    """
    Args:
        x: (V, H) Node features
        e: (E, H) Edge features
        edge_index: (2, E) Tensor of edges, each representing a connection
            from a source node to a target node.
    Returns:
        Updated x and e after one layer of the GNN.
    """
    # Deconstruct edge_index
    src, dest = edge_index  # shape: (E,)

    # --- Node Update ---
    w2_x_src = self.W2(x[src])  # shape: (E, H)
    messages = e * w2_x_src  # shape: (E, H)
    aggr_messages = torch.zeros_like(x)  # shape: (V, H)
    # index_add_ adds 'messages' to 'aggr_messages' at the row indices given by 'dest'
    aggr_messages.index_add_(0, dest, messages)  # shape: (V, H)
    x_new = x + F.relu(self.bn_x(self.W1(x) + aggr_messages))  # shape: (V, H)
```
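Written as an equation (a direct transcription of the code above, with $\mathrm{LN}$ denoting LayerNorm and $\odot$ elementwise multiplication), the node update is:

$$
x_i^{\text{new}} = x_i + \mathrm{ReLU}\Big(\mathrm{LN}\big(W_1 x_i + \textstyle\sum_{j:\,(j \to i) \in E} e_{j \to i} \odot W_2 x_j\big)\Big)
$$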
This line extracts the list of source vertices src and the list of destination vertices dest from the edges. Both are E-dimensional vectors: the source vertex index of edge i is src[i], and its destination vertex index is dest[i].
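As a minimal illustration (the toy graph below is made up purely for demonstration):

```python
import torch

# A toy graph with 3 nodes and 4 directed edges:
# 0 -> 1, 1 -> 2, 2 -> 0, 0 -> 2
edge_index = torch.tensor([[0, 1, 2, 0],
                           [1, 2, 0, 2]])
src, dest = edge_index
print(src)   # tensor([0, 1, 2, 0])
print(dest)  # tensor([1, 2, 0, 2])
```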
```python
w2_x_src = self.W2(x[src])  # shape: (E, H)
```
x[src] is advanced indexing on a Tensor. Row i of x is the feature vector of node i, so x[src] gathers from the node feature table x the features of every vertex listed in src, assembling them into a tensor of shape (E, H) whose row i is the feature vector of edge i's source vertex. The linear transformation self.W2 is then applied to each of these feature vectors, yielding w2_x_src.
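A quick sketch of this advanced-indexing behavior on toy tensors (the values are chosen only for illustration):

```python
import torch

x = torch.arange(12.0).reshape(4, 3)   # (V=4, H=3) node feature table
src = torch.tensor([0, 2, 2, 3])       # source node of each of E=4 edges
x_src = x[src]                         # (E, H): row i = features of edge i's source
print(x_src)
# tensor([[ 0.,  1.,  2.],
#         [ 6.,  7.,  8.],
#         [ 6.,  7.,  8.],
#         [ 9., 10., 11.]])
```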
```python
messages = e * w2_x_src  # shape: (E, H)
```
This step multiplies, elementwise, each edge's feature vector with the (transformed) feature vector of its source vertex. The resulting messages tensor has E rows; row i is a feature vector that fuses the features of edge i and of its source vertex.
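To make the subsequent index_add_ aggregation concrete, here is a toy sketch (dummy all-ones messages, not from the original code):

```python
import torch

V, E, H = 3, 4, 2
messages = torch.ones(E, H)            # pretend each edge carries an all-ones message
dest = torch.tensor([1, 2, 0, 2])      # destination node of each edge
aggr = torch.zeros(V, H)
aggr.index_add_(0, dest, messages)     # row v accumulates messages of all edges ending at v
print(aggr)
# tensor([[1., 1.],
#         [1., 1.],
#         [2., 2.]])   # node 2 is the target of two edges
```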
```python
class OutLayer(nn.Module):
    def __init__(self, hidden_dim: int, layer_num: int):
        """
        Args:
            hidden_dim: The dimension of the input edge features.
            layer_num: The number of layers in the MLP.
        """
        super(OutLayer, self).__init__()
        mlp_layers = []
        if layer_num == 1:
            mlp_layers.append(nn.Linear(hidden_dim, 2))
        else:
            mlp_layers.append(nn.Linear(hidden_dim, hidden_dim))
            mlp_layers.append(nn.ReLU())
            for _ in range(layer_num - 2):
                mlp_layers.append(nn.Linear(hidden_dim, hidden_dim))
                mlp_layers.append(nn.ReLU())
            mlp_layers.append(nn.Linear(hidden_dim, 2))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, e_final: Tensor):
        """
        Args:
            e_final: (E, H) Final edge features
        Returns:
            prob: (E, 2) Per-edge scores for being in / not in the TSP tour.
        """
        prob = self.mlp(e_final)  # shape: (E, 2)
        return prob
```
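A small usage sketch (the edge count and hidden dimension below are chosen only for illustration):

```python
import torch

out_layer = OutLayer(hidden_dim=64, layer_num=3)
e_final = torch.randn(100, 64)   # 100 edges with 64-dim features
prob = out_layer(e_final)
print(prob.shape)                # torch.Size([100, 2]) -- raw, unnormalized scores
```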
Note that the final prediction has shape (E, 2): the two entries of row i represent how likely edge i is to be "in" or "not in" the final tour. We do not apply softmax normalization here, because the softmax is folded into the loss computation.
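For instance, PyTorch's F.cross_entropy applies log_softmax internally, so the raw scores can be fed to it directly (a minimal sketch with dummy labels):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 2)              # (E, 2) raw scores from OutLayer
labels = torch.tensor([0, 1, 1, 0, 1])  # 1 if the edge is in the ground-truth tour
# log_softmax + negative log-likelihood are fused inside cross_entropy
loss = F.cross_entropy(logits, labels)
```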
```python
# gnn/encoder.py
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .embedder import Embedder
from .gnn_layer import GNNLayer
from .out_layer import OutLayer
```