# gnn/env.py
import numpy as np
import torch
from torch.utils.data import Dataset
class TSPDataset(Dataset):
    def __init__(self, file_path: str):
        # read the data from .txt
        with open(file_path, "r") as file:
            points_list = list()
            tour_list = list()
            for line in file:
                line = line.strip()
                split_line = line.split(" output ")

                # parse points
                points = split_line[0].split(" ")
                points = np.array(
                    [[float(points[i]), float(points[i + 1])] for i in range(0, len(points), 2)]
                )
                points_list.append(points)

                # parse tour
                tour = split_line[1].split(" ")
                tour = np.array([int(t) for t in tour])
                tour -= 1  # convert to 0-based index
                tour_list.append(tour)

        self.points = np.array(points_list)
        self.tours = np.array(tour_list)
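# A minimal usage sketch (not from the original source): each line of the .txt file is
# assumed to hold flattened coordinates, the literal token "output", and a 1-based tour,
# e.g. "0.1 0.2 0.5 0.7 0.9 0.3 output 1 3 2 1". The file path below is hypothetical.
dataset = TSPDataset(file_path="data/tsp50_train.txt")
print(dataset.points.shape)  # (num_instances, num_nodes, 2)
print(dataset.tours.shape)   # (num_instances, num_nodes + 1) when each tour closes back at the start node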
def decode(self, heatmap: Tensor, nodes_num: int, edge_index: Tensor):
    """
    Args:
        heatmap: (B, E) tensor with the probability of each edge being selected
        nodes_num: int, number of nodes
        edge_index: (B, 2, E) tensor with edges representing connections from source to target nodes
    Returns:
        tours: (B, V+1) array representing the decoded tours
    """
    # Convert to numpy for processing
    heatmap = heatmap.cpu().numpy()
    edge_index = edge_index.cpu().numpy()

    # Convert the sparse (B, E) heatmap into a dense (B, V, V) heatmap
    batch_size = heatmap.shape[0]
    heatmap_dense = np.zeros((batch_size, nodes_num, nodes_num), dtype=np.float32)
    for idx in range(batch_size):
        heatmap_dense[idx] = np_sparse_to_dense(
            nodes_num=nodes_num, edge_index=edge_index[idx], edge_attr=heatmap[idx]
        )

    # Decode the tour based on the dense heatmap
    if self.decoding_type == "greedy":
        return self._greedy_decode(heatmap_dense, batch_size, nodes_num)
    else:
        raise NotImplementedError(f"Decoding type '{self.decoding_type}' is not supported.")
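# np_sparse_to_dense comes from ml4co_kit here; the sketch below is an assumed equivalent
# (not the library's implementation) that scatters per-edge attributes into a dense
# (V, V) matrix, which is the representation decode hands to the greedy decoder.
def sparse_to_dense_sketch(nodes_num: int, edge_index: np.ndarray, edge_attr: np.ndarray) -> np.ndarray:
    # edge_index: (2, E) with source nodes in row 0 and target nodes in row 1
    # edge_attr:  (E,) attribute (here: edge probability) for each edge
    dense = np.zeros((nodes_num, nodes_num), dtype=np.float32)
    dense[edge_index[0], edge_index[1]] = edge_attr
    return dense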
def _greedy_decode(self, heatmap: np.ndarray, batch_size: int, nodes_num: int):
    """
    Args:
        heatmap: (B, V, V) numpy array representing the heatmap
        batch_size: int, number of samples in the batch
        nodes_num: int, number of nodes
    Returns:
        tours: (B, V+1) numpy array representing the decoded tours, closed at node 0
    """
    tours = []
    # Iterate over each instance
    for idx in range(batch_size):
        tour = []
        current = None
        for _ in range(nodes_num):
            if current is None:
                # Start from the first node
                next_node = 0
            else:
                # Select the next node with the highest probability
                next_node = np.argmax(heatmap[idx][current]).item()
            tour.append(next_node)
            heatmap[idx][:, next_node] = 0  # mask the selected node so it cannot be revisited
            current = next_node
        tour.append(0)  # Return to the starting node
        tours.append(np.array(tour))
    return np.array(tours)
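# A standalone toy trace (illustration only, hand-picked values) of the greedy rule above
# on a single 3-node heatmap: start at node 0, repeatedly move to the highest-probability
# unvisited node, then close the tour at node 0.
toy = np.array([[0.0, 0.1, 0.9],
                [0.8, 0.0, 0.2],
                [0.1, 0.7, 0.0]], dtype=np.float32)
tour, current = [0], 0
toy[:, 0] = 0.0                      # mask the start node
for _ in range(toy.shape[0] - 1):
    current = int(np.argmax(toy[current]))
    tour.append(current)
    toy[:, current] = 0.0            # mask the visited node
tour.append(0)                       # return to the start
print(tour)                          # [0, 2, 1, 0]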
def shared_step(self, batch, batch_idx, phase):
    """
    Shared step for training, validation, and testing.
    """
    self.env.mode = phase

    # unpack batch data
    x, e, edge_index, ground_truth, ref_tour = batch
    # x: (B, V, H), e: (B, E, H)
    # edge_index: (B, 2, E), ground_truth: (B, E)
    # ref_tour: (B, V+1)

    e_pred = self.model(x, e, edge_index)  # shape: (B, E, 2)
    loss = nn.CrossEntropyLoss()(e_pred.view(-1, 2), ground_truth.view(-1))

    if phase == "val":
        e_prob = torch.softmax(e_pred, dim=-1)  # shape: (B, E, 2)
        heatmap = e_prob[:, :, 1]  # shape: (B, E)
        tours = self.decoder.decode(heatmap, x.shape[1], edge_index)  # shape: (B, V+1)
        costs_avg, _, gap_avg, _ = self.evaluate(x, tours, ref_tour)

    # log
    metrics = {f"{phase}/loss": loss}
    if phase == "val":
        metrics.update({"val/costs_avg": costs_avg, "val/gap_avg": gap_avg})
    for k, v in metrics.items():
        self.log(k, float(v), prog_bar=True, on_epoch=True, sync_dist=True)

    # return
    return loss if phase == "train" else metrics
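# A minimal self-contained sketch (shapes and values are made up) of the loss and heatmap
# computation used in shared_step: per-edge two-class logits are flattened for
# nn.CrossEntropyLoss, and the class-1 probability is kept as the edge heatmap.
import torch
import torch.nn as nn

B, E = 4, 32                                 # hypothetical batch size and edge count
e_pred = torch.randn(B, E, 2)                # logits: class 0 = edge unused, class 1 = edge in tour
ground_truth = torch.randint(0, 2, (B, E))   # binary per-edge labels
loss = nn.CrossEntropyLoss()(e_pred.view(-1, 2), ground_truth.view(-1))
heatmap = torch.softmax(e_pred, dim=-1)[:, :, 1]  # (B, E) probability of each edge being selected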
def evaluate(self, x: Tensor, tours: Tensor, ref_tour: Tensor):
    """
    Evaluate the model's performance on a given set of tours.

    Args:
        x: (B, V, H) tensor representing node features.
        tours: (B, V+1) tensor representing predicted tours.
        ref_tour: (B, V+1) tensor representing reference tours.

    Returns:
        costs_avg: Average cost of the predicted tours.
        ref_costs_avg: Average cost of the reference tours.
        gap_avg: Average gap between predicted and reference tours.
        gap_std: Standard deviation of the gap.
    """
    x = x.cpu().numpy()
    ref_tour = ref_tour.cpu().numpy()
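# The body of evaluate is cut off above. A minimal sketch (an assumption, not the original
# implementation) of the outputs it describes: tour cost is taken as the Euclidean length
# of the closed tour, and the gap is the percentage excess over the reference cost.
def tour_costs_sketch(points: np.ndarray, tours: np.ndarray) -> np.ndarray:
    # points: (B, V, 2) coordinates; tours: (B, V+1) node indices closing at the start node
    coords = np.take_along_axis(points, tours[..., None], axis=1)            # (B, V+1, 2)
    return np.linalg.norm(coords[:, 1:] - coords[:, :-1], axis=-1).sum(axis=1)

def gap_stats_sketch(costs: np.ndarray, ref_costs: np.ndarray):
    gaps = (costs - ref_costs) / ref_costs * 100.0   # per-instance gap in percent
    return costs.mean(), ref_costs.mean(), gaps.mean(), gaps.std()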
# gnn/train.py
from .env import GNNEnv
from .encoder import GCNEncoder
from .decoder import GNNDecoder
from .model import GNNModel
from ml4co_kit import Trainer