Source code for pyagc.models.node2vec

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_CLUSTER
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import index2ptr

from pyagc.models.base import TrainableModel, LossOutput


class Node2Vec(TrainableModel):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of length
    :obj:`walk_length` are sampled in a given graph, and node embeddings are
    learned via negative sampling optimization.

    Args:
        edge_index (torch.Tensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for
            each node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in
            the walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy.
            (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples
            to use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes.
            (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
            the weight matrix will be sparse. (default: :obj:`False`)
        cpu_embedding (bool, optional): If set to :obj:`True`, stores
            embeddings on CPU and only moves required embeddings to GPU
            during training. Essential for very large graphs.
            (default: :obj:`False`)

    Raises:
        ImportError: If no usable random-walk backend is installed
            (:obj:`pyg-lib` is only used when ``p == q == 1.0``; otherwise
            :obj:`torch-cluster` is required).
        ValueError: If :obj:`walk_length` is smaller than
            :obj:`context_size`.
    """

    def __init__(
        self,
        edge_index: Tensor,
        embedding_dim: int,
        walk_length: int,
        context_size: int,
        walks_per_node: int = 1,
        p: float = 1.0,
        q: float = 1.0,
        num_negative_samples: int = 1,
        num_nodes: Optional[int] = None,
        sparse: bool = False,
        cpu_embedding: bool = False,
    ):
        super().__init__()

        # Pick the random-walk kernel: pyg-lib is only selected for unbiased
        # walks (p == q == 1); everything else goes through torch-cluster.
        if WITH_PYG_LIB and p == 1.0 and q == 1.0:
            self.random_walk_fn = torch.ops.pyg.random_walk
        elif WITH_TORCH_CLUSTER:
            self.random_walk_fn = torch.ops.torch_cluster.random_walk
        elif p == 1.0 and q == 1.0:
            raise ImportError(f"'{self.__class__.__name__}' "
                              f"requires either the 'pyg-lib' or "
                              f"'torch-cluster' package")
        else:
            raise ImportError(f"'{self.__class__.__name__}' "
                              f"requires the 'torch-cluster' package")

        if walk_length < context_size:
            raise ValueError(
                f"walk_length ({walk_length}) must be >= context_size ({context_size})"
            )

        self.num_nodes = maybe_num_nodes(edge_index, num_nodes)

        # The graph is kept on the CPU as (rowptr, col) built from the
        # sorted edge index; sampling runs inside DataLoader collation.
        row, col = sort_edge_index(edge_index, num_nodes=self.num_nodes).cpu()
        self.rowptr, self.col = index2ptr(row, self.num_nodes), col

        self.EPS = 1e-15
        self.embedding_dim = embedding_dim
        # Stored minus one: the start node supplies the first column of a
        # walk (see `neg_sample`, which appends this many random columns).
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.cpu_embedding = cpu_embedding
        # Device on which the loss arithmetic runs; updated by `to`.
        self.compute_device = torch.device('cpu')

        self.embedding = Embedding(
            self.num_nodes, embedding_dim, sparse=sparse
        )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.embedding.reset_parameters()

    def to(self, device):
        r"""Moves the model to :obj:`device`, honouring CPU embedding mode.

        The requested device is always recorded as
        :obj:`self.compute_device`. In CPU embedding mode the parameters are
        deliberately left where they are and only the compute device
        changes; otherwise the call is forwarded to the parent ``to``.

        Args:
            device (torch.device or str): The target device.

        Returns:
            The model itself.
        """
        self.compute_device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        if self.cpu_embedding:
            # Keep the embedding table on CPU; nothing else to move.
            return self
        return super().to(device)

    def embed(
        self,
        batch: Optional[Tensor] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Returns the embeddings for the nodes in :obj:`batch`.

        Args:
            batch (torch.Tensor, optional): Node indices. If :obj:`None`,
                returns embeddings for all nodes. (default: :obj:`None`)
            device (torch.device or str, optional): Target device for the
                embeddings. If :obj:`None`, uses :obj:`self.compute_device`
                in CPU embedding mode, or the embedding's own device in
                standard mode. (default: :obj:`None`)

        Returns:
            Node embeddings of shape :obj:`(num_nodes, embedding_dim)` or
            :obj:`(batch_size, embedding_dim)`.
        """
        if device is not None:
            # Accept both torch.device objects and plain strings.
            target_device = torch.device(device)
        elif self.cpu_embedding:
            target_device = self.compute_device
        else:
            target_device = self.embedding.weight.device

        weight = self.embedding.weight
        if batch is None:
            # CPU embedding mode hands out the raw `.data` tensor (as
            # before); standard mode returns the parameter itself.
            emb = weight.data if self.cpu_embedding else weight
        else:
            # Index on the device that holds the table, wherever the
            # indices came from.
            emb = weight[batch.to(weight.device)]

        # `.to` returns the tensor unchanged when it is already on the
        # target device, so this covers every device type uniformly.
        return emb.to(target_device)

    def loader(self, **kwargs) -> DataLoader:
        r"""Creates a DataLoader for training Node2Vec.

        Args:
            **kwargs: Forwarded to :class:`torch.utils.data.DataLoader`
                (e.g. :obj:`batch_size`, :obj:`shuffle`).

        Returns:
            DataLoader over node indices whose collate function is
            :meth:`sample`, so each batch is a pair of positive and negative
            random walks.
        """
        return DataLoader(range(self.num_nodes), collate_fn=self.sample,
                          **kwargs)

    def _context_windows(self, rw: Tensor) -> Tensor:
        r"""Cuts walks into sliding windows of :obj:`context_size` columns.

        Args:
            rw (torch.Tensor): Walks, one per row. Presumably
                :obj:`self.walk_length + 1` columns wide, as built in
                :meth:`neg_sample`.

        Returns:
            All windows concatenated along dim 0.
        """
        num_windows = 1 + self.walk_length + 1 - self.context_size
        windows = []
        for offset in range(num_windows):
            windows.append(rw[:, offset:offset + self.context_size])
        return torch.cat(windows, dim=0)

    @torch.jit.export
    def pos_sample(self, batch: Tensor) -> Tensor:
        r"""Samples positive random walks.

        Args:
            batch (torch.Tensor): Start node indices.

        Returns:
            Context windows cut from walks produced by the random-walk
            backend, :obj:`walks_per_node` walks per start node.
        """
        batch = batch.repeat(self.walks_per_node)
        rw = self.random_walk_fn(self.rowptr, self.col, batch,
                                 self.walk_length, self.p, self.q)
        # Some backends return a tuple; the walks are its first element.
        if not isinstance(rw, Tensor):
            rw = rw[0]
        return self._context_windows(rw)

    @torch.jit.export
    def neg_sample(self, batch: Tensor) -> Tensor:
        r"""Samples negative random walks.

        Each start node is followed by uniformly random node indices rather
        than by graph neighbours.

        Args:
            batch (torch.Tensor): Start node indices.

        Returns:
            Context windows cut from the random "walks",
            :obj:`walks_per_node * num_negative_samples` walks per start
            node.
        """
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
        rw = torch.randint(self.num_nodes, (batch.size(0), self.walk_length),
                           dtype=batch.dtype, device=batch.device)
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)
        return self._context_windows(rw)

    @torch.jit.export
    def sample(self, batch: Union[List[int], Tensor]) -> Tuple[Tensor, Tensor]:
        r"""Samples positive and negative random walks for a batch of nodes.

        Args:
            batch (list of int or torch.Tensor): Start node indices.

        Returns:
            Tuple :obj:`(pos_rw, neg_rw)`.
        """
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def _walk_logits(self, rw: Tensor) -> Tensor:
        r"""Dot products between each window's first node and its context.

        Args:
            rw (torch.Tensor): Windows of shape
                :obj:`(num_walks, context_size)`.

        Returns:
            Flattened scores, one per (start node, context node) pair,
            computed on :obj:`self.compute_device`.
        """
        # Look up on whichever device holds the table (the CPU in CPU
        # embedding mode); only the gathered rows move to compute_device.
        rw = rw.to(self.embedding.weight.device)
        num_walks = rw.size(0)
        start, rest = rw[:, 0], rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(num_walks, 1, self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(num_walks, -1,
                                                    self.embedding_dim)
        h_start = h_start.to(self.compute_device)
        h_rest = h_rest.to(self.compute_device)

        return (h_start * h_rest).sum(dim=-1).view(-1)

    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> LossOutput:
        r"""Computes the loss given positive and negative random walks.

        Args:
            pos_rw (torch.Tensor): Positive random walks of shape
                :obj:`(num_walks, context_size)`.
            neg_rw (torch.Tensor): Negative random walks of shape
                :obj:`(num_walks, context_size)`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        # EPS keeps the logarithm finite when the sigmoid saturates.
        pos_out = self._walk_logits(pos_rw)
        pos_loss = -torch.log(torch.sigmoid(pos_out) + self.EPS).mean()

        neg_out = self._walk_logits(neg_rw)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + self.EPS).mean()

        total = pos_loss + neg_loss

        return LossOutput(
            total=total,
            components={
                'pos': pos_loss.item(),
                'neg': neg_loss.item()
            }
        )

    def train_epoch(self, loader: DataLoader, optimizer: torch.optim.Optimizer,
                    epoch: int, verbose: bool = True) -> float:
        r"""Runs one epoch of Node2Vec training using the custom DataLoader.

        Args:
            loader (DataLoader): Node2Vec data loader created via
                :meth:`loader`.
            optimizer (torch.optim.Optimizer): The optimizer.
            epoch (int): Current epoch number.
            verbose (bool, optional): If :obj:`True`, prints training
                progress. (default: :obj:`True`)

        Returns:
            Average loss value of the epoch (:obj:`0.0` if the loader
            yielded no batches).
        """
        self.train()

        total_loss = 0.0
        total_pos = 0.0
        total_neg = 0.0
        num_batches = 0

        pbar = None
        if verbose:
            # Imported lazily so tqdm is only needed for verbose runs.
            from tqdm import tqdm
            pbar = tqdm(total=len(loader))
            pbar.set_description(f'Epoch {epoch:03d}')

        try:
            for pos_rw, neg_rw in loader:
                optimizer.zero_grad()

                # `loss` places the walks on the right device itself.
                loss_output = self.loss(pos_rw, neg_rw)
                loss_output.total.backward()
                optimizer.step()

                total_loss += loss_output.total.item()
                total_pos += loss_output.components['pos']
                total_neg += loss_output.components['neg']
                num_batches += 1

                if pbar is not None:
                    pbar.update(1)
        finally:
            # Close the bar even if a training step raised.
            if pbar is not None:
                pbar.close()

        # Guard against an empty loader (would otherwise divide by zero).
        denom = max(num_batches, 1)
        avg_loss = total_loss / denom
        avg_pos = total_pos / denom
        avg_neg = total_neg / denom

        if verbose:
            log_str = f"Epoch: {epoch:03d} Loss: {avg_loss:.4f}, POS: {avg_pos:.4f}, NEG: {avg_neg:.4f}"
            if self.logger:
                self.logger.info(log_str)
            else:
                print(log_str)

        return avg_loss

    def __repr__(self) -> str:
        # Report where the table actually lives instead of inferring it
        # from the cpu_embedding flag (a standard-mode model may be on CPU).
        on_cpu = self.embedding.weight.device.type == 'cpu'
        emb_location = "CPU" if on_cpu else "GPU"
        compute_loc = f"compute_on={self.compute_device}"
        return (f'{self.__class__.__name__}({self.embedding.weight.size(0)}, '
                f'{self.embedding.weight.size(1)}, '
                f'embedding_on={emb_location}, {compute_loc})')