Creating Custom Cluster Heads
This tutorial demonstrates how to create custom cluster heads in PyAGC. We’ll walk through the implementation of a new clustering mechanism and integrate it into the training pipeline.
Overview
PyAGC provides a modular architecture where cluster heads are standalone components that can be easily swapped. All cluster heads inherit from BaseClusterHead and implement two key methods:
- forward(): Computes the clustering loss during training
- cluster(): Generates cluster assignments during inference
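The BaseClusterHead Interface
The base class exported by pyagc.clusters defines the interface below. It is shown here for reference only; in your own code, import BaseClusterHead rather than redefining it.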
from pyagc.clusters import BaseClusterHead
import torch.nn as nn
class BaseClusterHead(nn.Module):
    """Interface shared by all cluster heads.

    Subclasses override ``forward`` (the training loss) and ``cluster``
    (inference-time assignments). Calling either method on a subclass that
    does not override it raises ``NotImplementedError`` naming the class.
    """

    def forward(self, *args, **kwargs):
        """Compute the clustering loss used during training.

        The accepted arguments are head-specific (for example, embeddings
        only, or embeddings plus ``edge_index``).

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )

    def cluster(self, z, soft: bool = False):
        """Generate cluster assignments.

        Args:
            z: Node embeddings (N, F).
            soft: If True, return probabilities; if False, return hard labels.

        Returns:
            Cluster assignments (N,) or probabilities (N, K).

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement cluster()"
        )
Example 1: Implementing a Simple Distance-Based Cluster Head
Let’s create a simple cluster head that assigns nodes to the nearest cluster center:
import torch
import torch.nn as nn
from pyagc.clusters import BaseClusterHead
class SimpleClusterHead(BaseClusterHead):
    """Cluster head that assigns each node to its nearest learnable center.

    Args:
        n_clusters: Number of clusters K.
        n_features: Embedding dimension F; must equal the last dimension of
            the embeddings passed to ``forward``/``cluster``.
    """

    def __init__(self, n_clusters, n_features):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_features = n_features
        # Learnable cluster centers, one row per cluster: (K, F)
        self.cluster_centers = nn.Parameter(
            torch.empty(n_clusters, n_features)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the cluster centers with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.cluster_centers)

    def forward(self, z):
        """Compute the clustering loss.

        The loss is the mean Euclidean (not squared) distance from each node
        embedding to its nearest cluster center.

        Args:
            z: Node embeddings (N, F).

        Returns:
            Scalar loss tensor.
        """
        # Pairwise Euclidean distances to every center: (N, K)
        dist = torch.cdist(z, self.cluster_centers, p=2)
        # The distance to the assigned (nearest) center is the row minimum.
        # Taking it directly avoids building a separate index tensor, which
        # ``torch.arange`` would otherwise create on the default (CPU) device.
        return dist.min(dim=-1).values.mean()

    @torch.no_grad()
    def cluster(self, z, soft=False):
        """Generate cluster assignments from distances to the centers.

        Args:
            z: Node embeddings (N, F).
            soft: If True, return a softmax over negated distances (N, K);
                otherwise return the index of the nearest center (N,).
        """
        dist = torch.cdist(z, self.cluster_centers, p=2)
        if soft:
            # Negate so that a smaller distance gives a higher probability
            return torch.softmax(-dist, dim=-1)
        return dist.argmin(dim=-1)
Example 2: Integrating with the DMoN Model
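Now let’s see how to use our custom cluster head with an existing model. We’ll modify the DMoN model to use our SimpleClusterHead. Make sure the replacement head’s forward() signature matches the way the model invokes its cluster head: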
from pyagc.models import DMoN
from pyagc.encoders import GCN
# Load the data. ``get_dataset`` yields (features, edge_index, labels);
# only the feature width is needed here.
from pyagc.data import get_dataset

x, edge_index, y = get_dataset('Cora')

# Keep the embedding size and cluster count in one place. The cluster head's
# ``n_features`` must match the encoder's output width (presumably
# ``hidden_channels`` for this GCN - confirm against the encoder docs).
HIDDEN_DIM = 64
N_CLUSTERS = 7

# Create encoder
encoder = GCN(
    in_channels=x.size(1),
    hidden_channels=HIDDEN_DIM,
    num_layers=2,
)
# Create model
model = DMoN(
    encoder=encoder,
    n_features=HIDDEN_DIM,
    n_clusters=N_CLUSTERS,
)
# Replace the default cluster head with our custom one
model.cluster_head = SimpleClusterHead(
    n_clusters=N_CLUSTERS,
    n_features=HIDDEN_DIM,
)
Example 3: Creating a Graph-Aware Cluster Head
For more sophisticated clustering, we can create a cluster head that considers graph structure. Here’s an example inspired by DMoN:
from torch_geometric.utils import degree
class GraphAwareClusterHead(BaseClusterHead):
    """Cluster head combining an embedding-distance term with an edge term.

    Soft assignments are a softmax over dot-product similarities between
    node embeddings and learnable cluster centers.

    Args:
        n_clusters: Number of clusters K.
        n_features: Embedding dimension F; must equal the last dimension of
            the embeddings passed to ``forward``/``cluster``.
        alpha: Weight of the structure (edge homogeneity) term.
    """

    def __init__(self, n_clusters, n_features, alpha=1.0):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_features = n_features
        self.alpha = alpha  # Trade-off between embedding and structure terms
        # Learnable cluster centers, one row per cluster: (K, F)
        self.cluster_centers = nn.Parameter(
            torch.empty(n_clusters, n_features)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the cluster centers with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.cluster_centers)

    def _logits(self, z):
        """Dot-product similarity of every node to every center: (N, K)."""
        return torch.matmul(z, self.cluster_centers.t())

    def forward(self, z, edge_index):
        """Combine embedding distance with graph structure.

        Args:
            z: Node embeddings (N, F).
            edge_index: Graph edges (2, E).

        Returns:
            Scalar loss: ``embed_loss + alpha * structure_loss``.
            ``structure_loss`` is 0 when the graph has no edges, so an
            edgeless graph or mini-batch does not produce a NaN loss.
        """
        # Soft assignment matrix: (N, K)
        S = torch.softmax(self._logits(z), dim=-1)

        # Embedding loss: Euclidean distance to each node's assigned center
        dist = torch.cdist(z, self.cluster_centers, p=2)
        assignments = S.argmax(dim=-1)
        # gather keeps the index on dist's device (no CPU arange needed)
        embed_loss = dist.gather(-1, assignments.unsqueeze(-1)).squeeze(-1).mean()

        # Structure loss: edges should connect nodes in the same cluster
        num_edges = edge_index.size(1)
        if num_edges == 0:
            # Without this guard the division below is 0/0 = NaN
            structure_loss = S.new_zeros(())
        else:
            src, dst = edge_index
            structure_loss = -torch.sum(S[src] * S[dst]) / num_edges

        return embed_loss + self.alpha * structure_loss

    @torch.no_grad()
    def cluster(self, z, soft=False):
        """Generate cluster assignments from dot-product similarities.

        Args:
            z: Node embeddings (N, F).
            soft: If True, return softmax probabilities (N, K); otherwise
                return the most similar center's index (N,).
        """
        sim = self._logits(z)
        if soft:
            return torch.softmax(sim, dim=-1)
        return sim.argmax(dim=-1)
Example 4: Using Custom Cluster Head in Training
Here’s a complete example showing how to train a model with a custom cluster head:
import torch
from torch_geometric.data import Data
from pyagc.data import get_dataset
from pyagc.encoders import GCN
from pyagc.models import ClusteringModel
# Load dataset
x, edge_index, y = get_dataset('Cora')
data = Data(x=x, edge_index=edge_index)

# Keep the embedding size and cluster count in one place: the encoder's
# hidden size is fed to the cluster head as ``n_features``, so the two
# must not drift apart.
HIDDEN_DIM = 64
N_CLUSTERS = 7

# Create model components
encoder = GCN(in_channels=x.size(1), hidden_channels=HIDDEN_DIM, num_layers=2)
cluster_head = GraphAwareClusterHead(n_clusters=N_CLUSTERS, n_features=HIDDEN_DIM)
# Simple wrapper model
class MyClusteringModel(ClusteringModel):
    """Wrap an encoder and a cluster head into one trainable model.

    Args:
        encoder: Module called as ``encoder(x, edge_index)`` to produce node
            embeddings.
        cluster_head: Head called as ``cluster_head(z, edge_index)`` for the
            loss and ``cluster_head.cluster(z, soft=False)`` for labels.
    """

    def __init__(self, encoder, cluster_head):
        super().__init__()
        self.encoder = encoder
        self.cluster_head = cluster_head

    def forward(self, data):
        """Return the clustering loss for ``data`` (uses ``x`` and ``edge_index``)."""
        z = self.encoder(data.x, data.edge_index)
        return self.cluster_head(z, data.edge_index)

    @torch.no_grad()
    def infer_full(self, data):
        """Return one hard cluster label per node in ``data``.

        Inference needs no gradients, so running under ``no_grad`` avoids
        building an autograd graph for the encoder pass.
        """
        z = self.encoder(data.x, data.edge_index)
        return self.cluster_head.cluster(z, soft=False)
model = MyClusteringModel(encoder, cluster_head)

# Training loop: one full-graph gradient step per epoch
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    loss = model(data)  # loss is measured before this epoch's update
    loss.backward()
    optimizer.step()
    if not epoch % 10:
        # Report progress every 10th epoch
        print(f'Epoch {epoch:03d}, Loss: {loss.item():.4f}')

# Inference: one hard cluster label per node
model.eval()
pred = model.infer_full(data)
print(f'Predicted clusters: {pred}')
Comparing with Existing Implementations
Let’s compare our simple implementation with PyAGC’s built-in DECClusterHead:
from pyagc.clusters import DECClusterHead
# Seed the RNG before building the heads so that the center initialization,
# the random embeddings, and therefore the printed numbers are reproducible.
torch.manual_seed(0)

# Our simple cluster head
simple_head = SimpleClusterHead(n_clusters=7, n_features=64)
# PyAGC's DEC cluster head
dec_head = DECClusterHead(n_clusters=7, n_features=64, alpha=1.0)

# Both can be used interchangeably
z = torch.randn(100, 64)  # Random embeddings

# Compute loss
loss_simple = simple_head(z)
loss_dec = dec_head(z)

# Get predictions
pred_simple = simple_head.cluster(z, soft=False)
pred_dec = dec_head.cluster(z, soft=False)

# .item() converts the scalar loss tensors to Python floats for formatting
print(f'Simple loss: {loss_simple.item():.4f}, DEC loss: {loss_dec.item():.4f}')
Key Takeaways
Inherit from BaseClusterHead: This ensures compatibility with PyAGC’s training pipelines
Implement forward() and cluster(): These are the only required methods
Use learnable parameters: Initialize cluster centers or other parameters as nn.Parameter
Support soft and hard assignments: The cluster() method should handle both cases
Consider graph structure: For attributed graph clustering, incorporating edge information often improves results
Next Steps
Check out the benchmark scripts for complete training examples
Experiment with different loss formulations and assignment strategies for your specific use case
Scale your cluster head to large graphs with mini-batch training
Review existing implementations in the clusters module
Check out the ECO framework tutorial for design patterns