        # Create separate linear layers for Query, Key, and Value
        self.fc_q = nn.Linear(embed_dim, embed_dim)
        self.fc_k = nn.Linear(embed_dim, embed_dim)
        self.fc_v = nn.Linear(embed_dim, embed_dim)
        # Create the final fully connected output layer
        self.fc_out = nn.Linear(embed_dim, embed_dim)
    def forward(self, x: Tensor):
        """
        Forward pass for the Multi-Head Self-Attention layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_nodes, embed_dim).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, num_nodes, embed_dim).
        """
        batch_size = x.shape[0]
        # 1. Project input into Q, K, V using separate linear layers
        Q = self.fc_q(x)  # Shape: (batch_size, num_nodes, embed_dim)
        K = self.fc_k(x)  # Shape: (batch_size, num_nodes, embed_dim)
        V = self.fc_v(x)  # Shape: (batch_size, num_nodes, embed_dim)
        # 2. Split the embed_dim into num_heads and head_dim
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, num_nodes, head_dim)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, num_nodes, head_dim)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, num_nodes, head_dim)
        # 3. Calculate scaled dot-product attention
        # Calculate the dot product of Q and K
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))  # Shape: (batch_size, num_heads, num_nodes, num_nodes)

        # Scale the attention scores
        scaled_attn_scores = attn_scores / math.sqrt(self.head_dim)

        # Apply softmax to get the attention weights
        attn_weights = F.softmax(scaled_attn_scores, dim=-1)  # Shape: (batch_size, num_heads, num_nodes, num_nodes)

        # Multiply the weights by V to get the context vector
        context = torch.matmul(attn_weights, V)  # Shape: (batch_size, num_heads, num_nodes, head_dim)
        # 4. Concatenate the attention heads' outputs
        # First, transpose to bring num_nodes and num_heads dimensions together
        context = context.transpose(1, 2).contiguous()  # Shape: (batch_size, num_nodes, num_heads, head_dim)

        # Then, reshape to combine the last two dimensions
        context = context.view(batch_size, -1, self.embed_dim)  # Shape: (batch_size, num_nodes, embed_dim)
        # 5. Pass the concatenated context vector through the final linear layer
        output = self.fc_out(context)  # Shape: (batch_size, num_nodes, embed_dim)

        return output
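As a quick sanity check, the layer can be run on a random input to confirm that the output shape matches the input. This is a minimal sketch, assuming the class above is named MultiHeadSelfAttention (the name imported in attention/encoder.py below) and that its constructor takes embed_dim and num_heads; the dimensions used here are arbitrary.

# Illustrative shape check for the self-attention layer (assumed constructor signature).
import torch

attn = MultiHeadSelfAttention(embed_dim=128, num_heads=8)
x = torch.randn(4, 20, 128)   # (batch_size=4, num_nodes=20, embed_dim=128)
out = attn(x)
print(out.shape)              # torch.Size([4, 20, 128])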
        # Create the first linear layer
        self.fc1 = nn.Linear(embed_dim, hidden_dim)

        # Create the second linear layer
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
    def forward(self, x: Tensor):
        """
        Forward pass for the Feed Forward Neural Network layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_nodes, embed_dim).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, num_nodes, embed_dim).
        """
        # Apply the first linear layer followed by ReLU activation
        x = F.relu(self.fc1(x))  # Shape: (batch_size, num_nodes, hidden_dim)

        # Apply the second linear layer
        output = self.fc2(x)  # Shape: (batch_size, num_nodes, embed_dim)

        return output
# attention/encoder.py
from torch import Tensor, nn

from .attn_layer import MultiHeadSelfAttention
from .ff_layer import FeedForward
class AttentionLayer(nn.Module):
    """
    A single Attention Layer that follows the structure from the image.
    It consists of a Multi-Head Attention sublayer and a Feed-Forward sublayer.
    Each sublayer is followed by a skip connection and Batch Normalization.
    """
    def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int):
        super(AttentionLayer, self).__init__()
        # Stack of identical Attention Layers
        self.layers = nn.ModuleList([
            AttentionLayer(embed_dim, num_heads, hidden_dim) for _ in range(num_layers)
        ])
    def forward(self, x: Tensor):
        """
        Forward pass for the Encoder.

        Args:
            x (torch.Tensor): Coordinates of nodes with shape (batch_size, num_nodes, 2).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, num_nodes, embed_dim).
        """
        # Embed the input coordinates
        x = self.embed(x)  # Shape: (batch_size, num_nodes, embed_dim)
        # Pass through multiple attention layers
        for layer in self.layers:
            x = layer(x)  # Shape: (batch_size, num_nodes, embed_dim)
        return x  # Shape: (batch_size, num_nodes, embed_dim)
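The encoder can be exercised end-to-end on random city coordinates. This is a minimal sketch, assuming the encoder class is named Encoder and its constructor takes embed_dim, num_heads, hidden_dim, and num_layers; neither the class name nor the exact constructor signature appears in the excerpt above, so treat both as assumptions.

# Illustrative usage of the encoder (assumed class name and constructor).
import torch

encoder = Encoder(embed_dim=128, num_heads=8, hidden_dim=512, num_layers=3)
coords = torch.rand(4, 20, 2)        # (batch_size=4, num_nodes=20, 2-D coordinates)
node_embeddings = encoder(coords)
print(node_embeddings.shape)         # torch.Size([4, 20, 128])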
# attention/decoder.py
import math

import torch
from torch import Tensor, nn
import torch.nn.functional as F
class MultiHeadMaskedCrossAttention(nn.Module):
    """
    Implements a Multi-Head Cross-Attention layer with masking.
    This layer is designed for a decoder that needs to attend to the output of an
    encoder. It takes a single context vector as the query source and a sequence of
    encoder outputs as the key and value source. It also supports masking to prevent
    attention to nodes that have already been visited in TSP.
    """
    def __init__(self, embed_dim: int, num_heads: int):
        super(MultiHeadMaskedCrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Linear layers for Query, Key, Value, and the final output projection
        self.fc_q = nn.Linear(embed_dim, embed_dim)
        self.fc_k = nn.Linear(embed_dim, embed_dim)
        self.fc_v = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
    def forward(self, context_query: Tensor, encoder_outputs: Tensor, mask: Tensor = None):
        """
        Forward pass for the Multi-Head Masked Cross-Attention layer.
        Args:
            context_query (torch.Tensor): The query tensor, typically derived from the
                decoder's state. Shape: (batch_size, 1, embed_dim).
            encoder_outputs (torch.Tensor): The key and value tensor, typically the
                output from the encoder. Shape: (batch_size, num_nodes, embed_dim).
            mask (torch.Tensor, optional): A boolean or 0/1 tensor to mask out certain
                keys. A value of 0 indicates the position should be masked.
                Shape: (batch_size, num_nodes).
        """
        # 1. Project Q from the context query and K, V from the encoder outputs.
        Q = self.fc_q(context_query)  # Shape: (batch_size, 1, embed_dim)
        K = self.fc_k(encoder_outputs)  # Shape: (batch_size, num_nodes, embed_dim)
        V = self.fc_v(encoder_outputs)  # Shape: (batch_size, num_nodes, embed_dim)
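        # NOTE: Steps 2 and 3 are not shown in this excerpt; the lines below are a
        # minimal reconstruction that mirrors the self-attention layer above, not
        # necessarily the author's exact code.
        batch_size = context_query.shape[0]

        # 2. Split the embed_dim into num_heads and head_dim
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, 1, head_dim)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, num_nodes, head_dim)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # Shape: (batch_size, num_heads, num_nodes, head_dim)

        # 3. Calculate the scaled dot-product attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # Shape: (batch_size, num_heads, 1, num_nodes)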
        # 4. Apply the mask before the softmax step.
        if mask is not None:
            # Reshape mask for broadcasting: (batch_size, num_nodes) -> (batch_size, 1, 1, num_nodes)
            mask_reshaped = mask.unsqueeze(1).unsqueeze(2)

            # Fill masked positions (where mask is 0) with a very small number.
            attn_scores = attn_scores.masked_fill(mask_reshaped == 0, -1e9)
# attention/decoder.py
class AttentionDecoder(nn.Module):
    """
    Implements the Decoder for the Attention Model.
    At each step, it creates a context embedding based on the graph, the first node,
    and the previously visited node. It then uses two attention mechanisms:
    1. A multi-head "glimpse" to refine the context.
    2. A single-head mechanism with clipping to calculate the final output probabilities.
    """
    def __init__(self, embed_dim: int, num_heads: int, clip_value: float = 10.0):
        super(AttentionDecoder, self).__init__()
        # Learned placeholders for the first and last nodes at the initial step (t=1)
        self.v_first_placeholder = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.v_last_placeholder = nn.Parameter(torch.randn(1, 1, embed_dim))
        # Projection layer for the concatenated context vector
        self.context_projection = nn.Linear(3 * embed_dim, embed_dim, bias=False)
        # The first attention mechanism: a multi-head "glimpse".
        self.glimpse_attention = MultiHeadMaskedCrossAttention(embed_dim, num_heads)
        # Layers for the final single-head attention mechanism to compute probabilities.
        self.final_q_projection = nn.Linear(embed_dim, embed_dim, bias=False)
        self.final_k_projection = nn.Linear(embed_dim, embed_dim, bias=False)
    def forward(self, encoder_outputs: Tensor, partial_tour: Tensor, mask: Tensor):
        """
        Performs a single decoding step.
        Args:
            encoder_outputs (torch.Tensor): The final node embeddings from the encoder.
                Shape: (batch_size, num_nodes, embed_dim).
            partial_tour (torch.Tensor): A tensor of node indices for the current
                partial tours. Shape: (batch_size, current_tour_length).
            mask (torch.Tensor): A tensor indicating which nodes are available to be
                visited. Shape: (batch_size, num_nodes).
        Returns:
            log_probs (torch.Tensor): The log-probabilities for selecting each node as
                the next step. Shape: (batch_size, num_nodes).
        """
        batch_size = encoder_outputs.shape[0]
        # Step 1: Construct the Context Embedding for the entire batch
        graph_embedding = encoder_outputs.mean(dim=1, keepdim=True)  # Shape: (batch_size, 1, embed_dim)
        if partial_tour.size(1) == 0:  # If this is the first step (t=1) for all instances
            # Use learned placeholders
            first_node_emb = self.v_first_placeholder.expand(batch_size, -1, -1)  # Shape: (batch_size, 1, embed_dim)
            last_node_emb = self.v_last_placeholder.expand(batch_size, -1, -1)  # Shape: (batch_size, 1, embed_dim)
        else:
            # Get indices of the first and last nodes for each instance in the batch
            first_node_indices = partial_tour[:, 0]  # Shape: (batch_size,)
            last_node_indices = partial_tour[:, -1]  # Shape: (batch_size,)
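            # NOTE: The lookup of the corresponding node embeddings is not shown in
            # this excerpt; the gather below is a minimal reconstruction, not
            # necessarily the author's exact code.
            first_node_emb = torch.gather(
                encoder_outputs, 1,
                first_node_indices.view(batch_size, 1, 1).expand(-1, -1, encoder_outputs.size(-1)),
            )  # Shape: (batch_size, 1, embed_dim)
            last_node_emb = torch.gather(
                encoder_outputs, 1,
                last_node_indices.view(batch_size, 1, 1).expand(-1, -1, encoder_outputs.size(-1)),
            )  # Shape: (batch_size, 1, embed_dim)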
        # Concatenate the three components to form the raw context
        raw_context = torch.cat([graph_embedding, first_node_emb, last_node_emb], dim=2)  # Shape: (batch_size, 1, 3 * embed_dim)
        # Project the context to create the initial query
        context_query = self.context_projection(raw_context)  # Shape: (batch_size, 1, embed_dim)
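        # NOTE: The remaining decoding steps are not shown in this excerpt. The sketch
        # below illustrates how the multi-head glimpse and the clipped single-head
        # attention described in the class docstring would typically produce the
        # log-probabilities; it assumes clip_value is stored as self.clip_value in
        # __init__ and is not necessarily the author's exact code.
        # Step 2: Refine the context query with the multi-head "glimpse".
        glimpse = self.glimpse_attention(context_query, encoder_outputs, mask)  # Shape: (batch_size, 1, embed_dim)

        # Step 3: Single-head attention with tanh clipping to obtain the logits.
        q = self.final_q_projection(glimpse)          # Shape: (batch_size, 1, embed_dim)
        k = self.final_k_projection(encoder_outputs)  # Shape: (batch_size, num_nodes, embed_dim)
        logits = torch.matmul(q, k.transpose(-2, -1)).squeeze(1) / math.sqrt(q.size(-1))  # Shape: (batch_size, num_nodes)
        logits = self.clip_value * torch.tanh(logits)

        # Mask already-visited nodes and convert to log-probabilities.
        logits = logits.masked_fill(mask == 0, -1e9)
        log_probs = F.log_softmax(logits, dim=-1)  # Shape: (batch_size, num_nodes)

        return log_probs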