pyagc.transforms.GSSLTransform
- class GSSLTransform(p_feat_mask: float = 0.5, p_edge_drop: float = 0.5, node_attrs: Optional[List[str]] = ['x'], edge_attrs: Optional[List[str]] = ['edge_attr'])
Bases: BaseTransform
Applies random feature masking and random edge dropping for Graph Self-Supervised Learning (functional name: gssl_transform).
This transform is commonly used in graph self-supervised learning methods such as GRACE, CCA-SSG, and BGRL.
For each node attribute in node_attrs, it randomly masks features with probability p_feat_mask. For each edge attribute in edge_attrs, it randomly drops edges with probability p_edge_drop. It works for both homogeneous and heterogeneous graphs.
Only the specified node and edge attributes are kept in the returned data.
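The following is a minimal, dependency-free sketch of what the two augmentations do. It is an illustration, not pyagc's implementation. Several assumptions are not confirmed by this page. The graph is a plain dict of lists instead of a PyG Data object. Feature masking is GRACE-style, zeroing whole feature columns for every node. Edges are dropped independently, with one keep-mask shared by the connectivity and every listed edge attribute. The connectivity is always kept. The function name gssl_augment and its seed argument are hypothetical.

```python
import random
from typing import Any, Dict, List, Optional


def gssl_augment(graph: Dict[str, Any], p_feat_mask: float = 0.5,
                 p_edge_drop: float = 0.5,
                 node_attrs: Optional[List[str]] = None,
                 edge_attrs: Optional[List[str]] = None,
                 seed: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical sketch of GSSL-style augmentation on a dict-based graph."""
    # None defaults avoid sharing one mutable list between calls.
    node_attrs = ["x"] if node_attrs is None else node_attrs
    edge_attrs = ["edge_attr"] if edge_attrs is None else edge_attrs
    rng = random.Random(seed)
    out: Dict[str, Any] = {}

    # Feature masking (assumed GRACE-style): each feature column is zeroed
    # with probability p_feat_mask, and the same columns are zeroed for all nodes.
    for name in node_attrs:
        rows = graph[name]
        dim = len(rows[0]) if rows else 0
        keep_col = [rng.random() >= p_feat_mask for _ in range(dim)]
        out[name] = [[v if k else 0.0 for v, k in zip(row, keep_col)] for row in rows]

    # Edge dropping: each edge survives with probability 1 - p_edge_drop; the
    # same keep-mask filters edge_index and every listed edge attribute.
    keep_edge = [rng.random() >= p_edge_drop for _ in graph["edge_index"]]
    out["edge_index"] = [e for e, k in zip(graph["edge_index"], keep_edge) if k]
    for name in edge_attrs:
        out[name] = [a for a, k in zip(graph[name], keep_edge) if k]

    # Attributes not listed (e.g. labels "y") are not copied to the output.
    return out


g = {
    "x": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "edge_index": [(0, 1), (1, 0)],
    "edge_attr": [[0.5], [0.7]],
    "y": [0, 1],
}
view = gssl_augment(g, p_feat_mask=0.5, p_edge_drop=0.5, seed=0)
print(sorted(view))  # ['edge_attr', 'edge_index', 'x']
```

Calling the sketch twice with different seeds produces two stochastic views of the same graph. GRACE-style contrastive methods consume augmentations in this way.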
- Parameters:
p_feat_mask (float, optional) – Probability of masking node features. (default: 0.5)
p_edge_drop (float, optional) – Probability of dropping edges. (default: 0.5)
node_attrs (List[str], optional) – Node attributes to transform and keep. (default: ["x"])
edge_attrs (List[str], optional) – Edge attributes to transform and keep. (default: ["edge_attr"])
- __init__(p_feat_mask: float = 0.5, p_edge_drop: float = 0.5, node_attrs: Optional[List[str]] = ['x'], edge_attrs: Optional[List[str]] = ['edge_attr'])
Methods
__init__([p_feat_mask, p_edge_drop, ...])
forward(data) – rtype: Union[Data, HeteroData]
- forward(data: Union[Data, HeteroData]) → Union[Data, HeteroData]
- Return type:
Union[Data, HeteroData]
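Because forward accepts either Data or HeteroData, the heterogeneous case presumably applies the same two operations to every node type and every edge type. The sketch below illustrates that reading with plain dicts standing in for HeteroData stores. Several parts are assumptions for illustration and not pyagc's API. These include the function augment_hetero, the store layout keyed by node type and by (src, relation, dst) tuples, and the column-wise masking drawn independently per node type. Drawing two views from one random generator mimics how GRACE-style methods obtain two differently augmented views.

```python
import random
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical stand-ins for HeteroData stores (not pyagc or PyG types).
NodeStores = Dict[str, Dict[str, Any]]                   # node type -> {attr: rows}
EdgeStores = Dict[Tuple[str, str, str], Dict[str, Any]]  # (src, rel, dst) -> {attr: rows}


def augment_hetero(nodes: NodeStores, edges: EdgeStores, p_feat_mask: float,
                   p_edge_drop: float, rng: random.Random,
                   node_attrs: Optional[List[str]] = None,
                   edge_attrs: Optional[List[str]] = None) -> Tuple[NodeStores, EdgeStores]:
    """Hypothetical per-type sketch; not pyagc's implementation."""
    node_attrs = ["x"] if node_attrs is None else node_attrs
    edge_attrs = ["edge_attr"] if edge_attrs is None else edge_attrs

    new_nodes: NodeStores = {}
    for ntype, store in nodes.items():
        new_nodes[ntype] = {}
        for name in node_attrs:
            if name not in store:
                continue
            rows = store[name]
            # Assumed column-wise masking, drawn independently per node type.
            keep = [rng.random() >= p_feat_mask for _ in range(len(rows[0]) if rows else 0)]
            new_nodes[ntype][name] = [[v if k else 0.0 for v, k in zip(r, keep)] for r in rows]

    new_edges: EdgeStores = {}
    for etype, store in edges.items():
        # One keep-mask per edge type, shared by edge_index and its edge attributes.
        keep = [rng.random() >= p_edge_drop for _ in store["edge_index"]]
        out = {"edge_index": [e for e, k in zip(store["edge_index"], keep) if k]}
        for name in edge_attrs:
            if name in store:
                out[name] = [a for a, k in zip(store[name], keep) if k]
        new_edges[etype] = out
    return new_nodes, new_edges


nodes = {"paper": {"x": [[1.0, 2.0], [3.0, 4.0]], "year": [2020, 2021]},
         "author": {"x": [[5.0, 6.0]]}}
edges = {("author", "writes", "paper"): {"edge_index": [(0, 0), (0, 1)],
                                         "edge_attr": [[1.0], [2.0]]}}
rng = random.Random(0)
view1 = augment_hetero(nodes, edges, 0.3, 0.2, rng)  # first stochastic view
view2 = augment_hetero(nodes, edges, 0.3, 0.2, rng)  # second, independently drawn
```

In this sketch, the unlisted "year" attribute is absent from both views. This matches the documented behavior of keeping only the specified attributes.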