```python
# attention/env.py
import numpy as np
from torch.utils.data import Dataset


class TSPDataset(Dataset):
    def __init__(self, file_path: str):
        # read the data from .txt
        with open(file_path, "r") as file:
            points_list = list()
            tour_list = list()
            for line in file:
                line = line.strip()
                split_line = line.split(" output ")
                # parse points
                points = split_line[0].split(" ")
                points = np.array(
                    [[float(points[i]), float(points[i + 1])] for i in range(0, len(points), 2)]
                )
                points_list.append(points)
                # parse tour
                tour = split_line[1].split(" ")
                tour = np.array([int(t) for t in tour])
                tour -= 1  # convert to 0-based index
                tour_list.append(tour)
        self.points = np.array(points_list)
        self.tours = np.array(tour_list)
```
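To make the expected file format concrete, here is a minimal smoke test; the file name and sample values are made up, but the line follows the `<x1 y1 x2 y2 ...> output <1-based closed tour>` layout that `__init__` parses above.

```python
# Hypothetical smoke test for TSPDataset: one 3-node instance.
sample_line = "0.1 0.2 0.8 0.9 0.4 0.5 output 1 3 2 1\n"
with open("tiny_tsp.txt", "w") as f:
    f.write(sample_line)

dataset = TSPDataset("tiny_tsp.txt")
print(dataset.points.shape)  # (1, 3, 2): one instance with 3 nodes
print(dataset.tours[0])      # [0 2 1 0]: the tour shifted to 0-based indices
```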
```python
# attention/env.py
from dataclasses import dataclass
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from ml4co_kit import BaseEnv


@dataclass
class StepState:
    """
    A data class to hold the state of the environment at each decoding step.
    This makes passing state information to the model cleaner.
    """
    current_node: Tensor = None  # Shape: (batch,)
    tours: Tensor = None         # Shape: (batch, time_step)
    mask: Tensor = None          # Shape: (batch, num_nodes)

class AttentionEnv(BaseEnv):
    def __init__(
        self,
        mode: str = "train",
        train_batch_size: int = 4,
        val_batch_size: int = 4,
        train_path: str = None,
        val_path: str = None,
        num_workers: int = 4,
        device: str = "cpu",
    ):
        super(AttentionEnv, self).__init__(
            name="AttentionEnv",
            mode=mode,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            train_path=train_path,
            val_path=val_path,
            num_workers=num_workers,
            device=device
        )
        if mode is not None:
            self.load_data()
            self.num_nodes = self.train_dataset.points.shape[1] if self.train_dataset else None
        self.points = None
        self.batch_size = None
        # These will be managed during reset and step
        self.current_node = None
        self.tours = None
        self.mask = None

    def reset(self, points: Tensor):
        """
        Resets the environment for a new rollout.
        """
        self.points = points.to(self.device)  # Shape: (batch_size, num_nodes, 2)
        self.batch_size = self.points.size(0)
        self.current_node = None
        self.tours = torch.zeros((self.batch_size, 0), dtype=torch.long, device=self.device)
        self.mask = torch.ones((self.batch_size, self.num_nodes), device=self.device)
        state_step = StepState(current_node=self.current_node, tours=self.tours, mask=self.mask)
        return state_step, None, None  # Initial state, no reward, not done

    def step(self, selected_node: Tensor):
        """
        Updates the environment state based on the selected node.

        Args:
            selected_node (Tensor): The node selected by the policy model. Shape: (batch_size,).

        Returns:
            A tuple containing:
            - state (StepState): The new state of the environment.
            - reward (Tensor or None): The final reward (negative tour length) if done, else None.
            - done (bool): A boolean indicating if the tour is complete.
        """
        self.current_node = selected_node
        self.tours = torch.cat([self.tours, self.current_node.unsqueeze(-1)], dim=1)
        # Mark the selected node as visited
        self.mask.scatter_(dim=1, index=self.current_node.unsqueeze(-1), value=0)

        done = (self.tours.size(1) == self.num_nodes)
        reward = -self.evaluate() if done else None  # Negative tour length as reward
        state_step = StepState(current_node=self.current_node, tours=self.tours, mask=self.mask)
        return state_step, reward, done

    def evaluate(self):
        """
        Calculates the total length of the generated tours.

        Returns:
            Tensor: The total length for each tour in the batch. Shape: (batch_size,).
        """
        # Gather coordinates in tour order.
        # self.tours.shape: (batch_size, num_nodes)
        tour_coords = torch.gather(
            input=self.points, dim=1, index=self.tours.unsqueeze(-1).expand(-1, -1, 2)
        )  # Shape: (batch_size, num_nodes, 2)

        # Calculate distances between consecutive nodes, including returning to the start
        rolled_coords = tour_coords.roll(dims=1, shifts=-1)
        segment_lengths = torch.norm(tour_coords - rolled_coords, dim=2)

        return segment_lengths.sum(dim=1)
```
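In formula form, for a tour $\pi = (\pi_1, \dots, \pi_n)$ over coordinates $x$, `evaluate` computes the closed tour length

$$
L(\pi) = \sum_{i=1}^{n-1} \lVert x_{\pi_i} - x_{\pi_{i+1}} \rVert_2 + \lVert x_{\pi_n} - x_{\pi_1} \rVert_2 ,
$$

where `roll(shifts=-1)` pairs each node with its successor and wraps the last node back to the first, so both terms are covered by a single subtraction and `norm`.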
The DataLoader-related functions are the same as before and are omitted here.
StepState represents the environment state, wrapped in a data class (roughly the equivalent of a C++ struct). It holds current_node (the node selected at the current step), tours (the partial tour constructed so far), and mask (the current mask, where visited nodes are 0 and all remaining nodes are 1).
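As a quick illustration of how this state evolves, the sketch below drives the environment by hand with random points and a dummy selection rule instead of a trained policy. It assumes, as the constructor above suggests, that passing `mode=None` skips data loading, so `num_nodes` is set manually.

```python
# Minimal sketch: stepping AttentionEnv by hand with a dummy "policy".
# Assumes mode=None skips data loading, as the constructor above suggests.
import torch

env = AttentionEnv(mode=None, device="cpu")
env.num_nodes = 5                      # set manually since no dataset is loaded
points = torch.rand(2, 5, 2)           # batch of 2 instances, 5 nodes each

state, reward, done = env.reset(points)
print(state.mask)                      # all ones: nothing visited yet

while not done:
    # always pick the first still-unvisited node (mask == 1) in each instance
    selected = state.mask.argmax(dim=1)
    state, reward, done = env.step(selected)

print(state.tours)                     # (2, 5): visiting order per instance
print(-reward)                         # (2,): tour lengths, since reward = -length
```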
```python
# attention/policy.py
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from torch.distributions import Categorical
from .env import AttentionEnv
from .encoder import AttentionEncoder
from .decoder import AttentionDecoder


@dataclass
class StepState:
    """
    A data class to hold the state of the environment at each decoding step.
    This makes passing state information to the model cleaner.
    """
    current_node: Tensor = None  # Shape: (batch,)
    tours: Tensor = None         # Shape: (batch, time_step)
    mask: Tensor = None          # Shape: (batch, num_nodes)

class AttentionPolicy(nn.Module):
    # The constructor (storing env, encoder, and decoder) is omitted in this excerpt.

    def forward(self, points: Tensor, mode: str = "sampling"):
        """
        Performs a full rollout to generate a tour for a batch of TSP instances.

        Args:
            points (torch.Tensor): Node coordinates for the batch. Shape: (batch_size, num_nodes, 2).
            mode (str): 'sampling' for stochastic rollout or 'greedy' for deterministic.

        Returns:
            A tuple containing:
            - reward (torch.Tensor): Reward for each instance in the batch. Shape: (batch_size,).
            - sum_log_probs (torch.Tensor): Sum of action log probabilities. Shape: (batch_size,).
            - tour (torch.Tensor): The decoded tour for each instance. Shape: (batch_size, num_nodes + 1).
        """
        batch_size = points.size(0)

        # ... the encoder forward pass, env.reset, and the step-by-step decoding loop
        # (which produce `state`, `reward`, and `sum_log_probs`) are not shown in this excerpt ...

        tour = state.tours                           # Shape: (batch_size, num_nodes)
        start_node = tour[:, 0].unsqueeze(1)         # Shape: (batch_size, 1)
        tour = torch.cat([tour, start_node], dim=1)  # Append the start node to the end of the tour
        return reward, sum_log_probs, tour
```
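The decoding loop itself is not shown in the excerpt above. A rollout driven by the environment's reset/step API could look roughly like the sketch below; the attribute names (`policy.encoder`, `policy.decoder`, `policy.env`) and the call signatures (the decoder returning masked per-node probabilities given the embeddings and the current StepState) are assumptions for illustration, not the actual AttentionEncoder/AttentionDecoder interfaces.

```python
# Hypothetical sketch of a sampling/greedy rollout using AttentionEnv.
# The encoder/decoder call signatures are assumptions for illustration only.
import torch
from torch.distributions import Categorical

def rollout_sketch(policy, points, mode="sampling"):
    embeddings = policy.encoder(points)              # assumed: (batch, num_nodes, hidden)
    state, reward, done = policy.env.reset(points)
    sum_log_probs = torch.zeros(points.size(0), device=points.device)
    while not done:
        probs = policy.decoder(embeddings, state)    # assumed: (batch, num_nodes), masked
        if mode == "sampling":
            dist = Categorical(probs=probs)
            selected = dist.sample()
            sum_log_probs = sum_log_probs + dist.log_prob(selected)
        else:  # greedy
            selected = probs.argmax(dim=-1)
        state, reward, done = policy.env.step(selected)
    return reward, sum_log_probs, state.tours
```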
```python
# attention/model.py
import copy
import numpy as np
import torch
from torch import Tensor
from ml4co_kit import BaseModel, TSPSolver
from .env import AttentionEnv
from .encoder import AttentionEncoder
from .decoder import AttentionDecoder
from .policy import AttentionPolicy

class AttentionModel(BaseModel):
    def __init__(
        self,
        env: AttentionEnv,
        encoder: AttentionEncoder,
        decoder: AttentionDecoder,
        lr_scheduler: str = "cosine-decay",
        learning_rate: float = 2e-4,
        weight_decay: float = 1e-4,
    ):
        super(AttentionModel, self).__init__(
            env=env,
            # The main model to be trained
            model=AttentionPolicy(
                env=env,
                encoder=encoder,
                decoder=decoder,
            ),
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
            weight_decay=weight_decay
        )
        self.to(self.env.device)

        # Create a separate baseline model
        baseline_encoder = copy.deepcopy(encoder)
        baseline_decoder = copy.deepcopy(decoder)
        self.baseline_model = AttentionPolicy(
            env=env,
            encoder=baseline_encoder,
            decoder=baseline_decoder,
        ).to(self.env.device)
        self.baseline_model.eval()  # Set to evaluation mode permanently
        self.update_baseline()      # Initialize baseline with policy weights

        # Store validation metrics
        self.val_metrics = []

    def update_baseline(self):
        """Copies the weights from the policy model to the baseline model."""
        self.baseline_model.load_state_dict(self.model.state_dict())

    def shared_step(self, batch, batch_idx, phase):
        """
        Shared step for training, validation, and testing.
        """
        self.env.mode = phase
        # unpack batch data
        points, ref_tours = batch
        # points: (batch_size, num_nodes, 2)
        # ref_tours: (batch_size, num_nodes + 1)

        if phase == "train":
            # --- 1. Policy Rollout (stochastic) ---
            # Gradients are tracked for this rollout.
            self.model.train()  # Ensure model is in training mode
            reward, sum_log_probs, tours = self.model(points, mode='sampling')
            policy_cost = -reward  # Reward is negative tour length
        elif phase == "val":
            with torch.no_grad():
                self.model.eval()  # Set model to evaluation mode
                # Evaluate the policy model
                reward, sum_log_probs, tours = self.model(points, mode='greedy')
                policy_cost = -reward

        # --- 2. Baseline Rollout (greedy) ---
        # No gradients are needed for the baseline.
        with torch.no_grad():
            reward, _, baseline_tours = self.baseline_model(points, mode='greedy')
            baseline_cost = -reward  # Reward is negative tour length

        # --- 3. Calculate REINFORCE Loss ---
        # The advantage is the gap between the sampled solution and the greedy baseline.
        advantage = policy_cost - baseline_cost
        # The loss is the mean of the advantage-weighted log-probabilities.
        loss = (advantage * sum_log_probs).mean()
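        # In expectation this is the REINFORCE gradient with a greedy rollout
        # baseline: E[(c(pi) - c(pi_BL)) * grad log p(pi)]. Since neither cost
        # carries gradients, minimizing `loss` lowers the log-probability of
        # sampled tours that are longer than the baseline's greedy tours and
        # raises it for tours that are shorter.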

        # --- 4. Logging ---
        metrics = {f"{phase}/loss": loss}
        # print(f"loss: {loss.item()}")
        if phase == "val":
            # Validation metrics: average predicted cost, gap w.r.t. the
            # reference tours, and the baseline's average cost.
            costs_avg, _, gap_avg, _ = self.evaluate(points, tours, ref_tours)
            baseline_costs_avg = baseline_cost.mean().item()
            metrics.update({
                "val/costs_avg": costs_avg,
                "val/gap_avg": gap_avg,
                "val/baseline_costs_avg": baseline_costs_avg,
            })
            self.val_metrics.append(metrics)
        for k, v in metrics.items():
            self.log(k, float(v), prog_bar=True, on_epoch=True, sync_dist=True)
        # return
        return loss if phase == "train" else metrics

    def on_validation_epoch_end(self):
        # Aggregate the costs from all validation batches
        avg_policy_cost = np.array([x['val/costs_avg'] for x in self.val_metrics]).mean()
        avg_baseline_cost = np.array([x['val/baseline_costs_avg'] for x in self.val_metrics]).mean()
        # Baseline Update
        if avg_policy_cost < avg_baseline_cost:
            self.update_baseline()
        self.val_metrics.clear()  # Clear the metrics for the next epoch

    def evaluate(self, x: Tensor, tours: Tensor, ref_tours: Tensor):
        """
        Evaluate the model's performance on a given set of tours.

        Args:
            x: (batch_size, num_nodes, 2) tensor representing node coordinates.
            tours: (batch_size, num_nodes+1) tensor representing predicted tours.
            ref_tours: (batch_size, num_nodes+1) tensor representing reference tours.

        Returns:
            costs_avg: Average cost of the predicted tours.
            ref_costs_avg: Average cost of the reference tours.
            gap_avg: Average gap between predicted and reference tours.
            gap_std: Standard deviation of the gap.
        """
        x = x.cpu().numpy()
        tours = tours.cpu().numpy()
        ref_tours = ref_tours.cpu().numpy()
        # ... the remaining cost/gap computation (presumably via the imported
        # TSPSolver) is not shown in this excerpt ...
```
```python
# attention/train.py
from .env import AttentionEnv
from .encoder import AttentionEncoder
from .decoder import AttentionDecoder
from .model import AttentionModel
from ml4co_kit import Trainer
```