Source code for pyagc.transforms.gssl_transform

from typing import List, Optional, Union

import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


[docs]@functional_transform('gssl_transform') class GSSLTransform(BaseTransform): r"""Applies random feature masking and random edge dropping for Graph Self-Supervised Learning (functional name: :obj:`gssl_transform`). This transform is commonly used in graph self-supervised learning methods such as `GRACE <https://arxiv.org/abs/2006.04131>`_, `CCA-SSG <https://arxiv.org/abs/2106.12484>`_, and `BGRL <https://arxiv.org/abs/2102.06514>`_. For each node attribute in :obj:`node_attrs`, randomly masks features. For each edge attribute in :obj:`edge_attrs`, randomly drops edges. Works for both homogeneous and heterogeneous graphs. Only keeps specified node attributes and edge attributes in the returned data. Args: p_feat_mask (float, optional): Probability of masking node features. (default: :obj:`0.5`) p_edge_drop (float, optional): Probability of dropping edges. (default: :obj:`0.5`) node_attrs (List[str], optional): Node attributes to transform and keep. (default: :obj:`["x"]`) edge_attrs (List[str], optional): Edge attributes to transform and keep. (default: :obj:`["edge_attr"]`) """
[docs] def __init__( self, p_feat_mask: float = 0.5, p_edge_drop: float = 0.5, node_attrs: Optional[List[str]] = ["x"], edge_attrs: Optional[List[str]] = ["edge_attr"], ): for p in (p_feat_mask, p_edge_drop): if p < 0. or p > 1.: raise ValueError(f'Masking ratio has to be between 0 and 1 ' f'(got {p}') self.p_feat_mask = p_feat_mask self.p_edge_drop = p_edge_drop self.node_attrs = node_attrs self.edge_attrs = edge_attrs
def _mask_features(self, x: torch.Tensor) -> torch.Tensor: if self.p_feat_mask == 0.0 or x.numel() == 0: return x mask = torch.rand_like(x) < self.p_feat_mask x = x.clone() x[mask] = 0 return x def _drop_edges( self, edge_index: torch.Tensor, edge_attrs: List[Optional[torch.Tensor]] ) -> (torch.Tensor, List[Optional[torch.Tensor]]): if self.p_edge_drop == 0.0 or edge_index.numel() == 0: return edge_index, edge_attrs num_edges = edge_index.size(1) mask = torch.rand(num_edges) >= self.p_edge_drop edge_index = edge_index[:, mask] new_edge_attrs = [] for edge_attr in edge_attrs: if edge_attr is not None: new_edge_attrs.append(edge_attr[mask]) else: new_edge_attrs.append(None) return edge_index, new_edge_attrs def __call__(self, *args, **kwargs) -> Union[dict, Data, HeteroData]: r""" Supports both Data object and separate arguments. If called with a Data object: transform(data) -> Data If called with separate arguments: transform(x, edge_index, ...) -> dict with transformed values """ # Case 1: Called with Data/HeteroData object if len(args) == 1 and isinstance(args[0], (Data, HeteroData)): return self.forward(args[0]) # Case 2: Called with separate arguments (x, edge_index, ...) # Reconstruct from args and kwargs result = {} # Handle positional args (assumed to be x, edge_index in order) if len(args) >= 1: x = args[0] result['x'] = self._mask_features(x) if len(args) >= 2: edge_index = args[1] edge_attrs_values = [kwargs.get(attr) for attr in self.edge_attrs] edge_index, new_edge_attrs = self._drop_edges(edge_index, edge_attrs_values) result['edge_index'] = edge_index for attr, value in zip(self.edge_attrs, new_edge_attrs): if value is not None: result[attr] = value # Handle kwargs for key, value in kwargs.items(): if key == 'x' and 'x' not in result: result['x'] = self._mask_features(value) elif key == 'edge_index' and 'edge_index' not in result: edge_attrs_values = [kwargs.get(attr) for attr in self.edge_attrs] edge_index, new_edge_attrs = self._drop_edges(value, edge_attrs_values) result['edge_index'] = edge_index for attr, val in zip(self.edge_attrs, new_edge_attrs): if val is not None: result[attr] = val elif key not in self.edge_attrs and key not in result: result[key] = value return result
[docs] def forward(self, data: Union[Data, HeteroData]) -> Union[Data, HeteroData]: out = data.__class__() if isinstance(data, Data): # Mask node attributes for attr in self.node_attrs: if hasattr(data, attr): value = getattr(data, attr) out[attr] = self._mask_features(value) # Drop edges and corresponding edge attributes edge_attrs_values = [getattr(data, attr, None) for attr in self.edge_attrs] edge_index, new_edge_attrs = self._drop_edges( data.edge_index, edge_attrs_values ) out.edge_index = edge_index for attr, value in zip(self.edge_attrs, new_edge_attrs): if value is not None: out[attr] = value elif isinstance(data, HeteroData): for node_type in data.node_types: for attr in self.node_attrs: if attr in data[node_type]: out[node_type][attr] = self._mask_features(data[node_type][attr]) for edge_type in data.edge_types: edge_attrs_values = [data[edge_type].get(attr, None) for attr in self.edge_attrs] edge_index, new_edge_attrs = self._drop_edges( data[edge_type].edge_index, edge_attrs_values ) out[edge_type].edge_index = edge_index for attr, value in zip(self.edge_attrs, new_edge_attrs): if value is not None: out[edge_type][attr] = value return out