pyagc.utils.off_diagonal

off_diagonal(x: Tensor) → Tensor

Extract off-diagonal elements from a square matrix.

Returns a flattened, one-dimensional tensor of all off-diagonal elements of a square matrix. This is useful for losses or metrics that must exclude the diagonal, such as off-diagonal regularization in self-supervised learning.
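A common way to extract off-diagonal entries without building a mask is to reshape the flattened matrix so that every row begins with a diagonal element, then drop the first column. The sketch below shows that trick. It is a hypothetical re-implementation (`off_diagonal_sketch` is not part of pyagc), and pyagc's actual code may differ.

```python
import torch
from torch import Tensor


def off_diagonal_sketch(x: Tensor) -> Tensor:
    """Hypothetical re-implementation; pyagc's actual code may differ."""
    n, m = x.shape
    assert n == m, "input must be a square matrix"
    # Drop the last element (the final diagonal entry) so the remaining
    # n*n - 1 entries can be viewed as (n - 1) rows of length n + 1.
    # Each of these rows then starts with a diagonal element.
    rows = x.flatten()[:-1].view(n - 1, n + 1)
    # Skip the leading diagonal entry of every row and flatten the rest.
    return rows[:, 1:].flatten()


x = torch.arange(1, 10).reshape(3, 3)
print(off_diagonal_sketch(x))  # tensor([2, 3, 4, 6, 7, 8])
```

For a 3 × 3 input, the intermediate `rows` tensor is `[[1, 2, 3, 4], [5, 6, 7, 8]]`. Removing the first column leaves exactly the six off-diagonal values in row-major order.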

Parameters:

x (Tensor) – A square matrix of shape (n, n).

Returns:

Flattened tensor of shape (n * (n - 1),) containing all off-diagonal elements in row-major order.

Return type:

Tensor

Raises:

AssertionError – If the input is not a square matrix.
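The documented output shape and ordering can be cross-checked against a boolean-mask formulation, because boolean indexing in PyTorch gathers the selected entries in row-major order. The helper `off_diagonal_masked` below is a hypothetical alternative for illustration only. It always allocates a copy, and it raises `AssertionError` on non-square input to mirror the contract above.

```python
import torch
from torch import Tensor


def off_diagonal_masked(x: Tensor) -> Tensor:
    """Hypothetical mask-based alternative; always returns a copy."""
    n, m = x.shape
    assert n == m, "input must be a square matrix"
    # True everywhere except on the main diagonal.
    mask = ~torch.eye(n, dtype=torch.bool, device=x.device)
    # Boolean indexing returns a 1-D tensor of the selected entries, row-major.
    return x[mask]


x = torch.arange(16).reshape(4, 4)
out = off_diagonal_masked(x)
print(out.shape)  # torch.Size([12]), i.e. n * (n - 1) for n = 4
```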

Example

>>> import torch
>>> from pyagc.utils import off_diagonal
>>> x = torch.tensor([[1, 2, 3],
...                   [4, 5, 6],
...                   [7, 8, 9]])
>>> off_diagonal(x)
tensor([2, 3, 4, 6, 7, 8])

Note

This function is memory-efficient as it returns a view rather than a copy of the data when possible.
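Because "when possible" depends on the input's memory layout, a quick empirical check can help. Write into the returned tensor and see whether the input changes. The helper `returns_view` below is a hypothetical utility, not part of pyagc. It is shown here on two reference functions with known behavior, and the same call could be applied to `off_diagonal` itself.

```python
from typing import Callable

import torch
from torch import Tensor


def returns_view(fn: Callable[[Tensor], Tensor], n: int = 4) -> bool:
    """Return True if in-place writes to fn's output show up in its input."""
    x = torch.zeros(n, n)
    out = fn(x)
    out.fill_(1.0)  # mutate the result in place
    # If `out` shares storage with `x`, the fill is visible in `x`.
    return bool(x.sum() > 0)


# x.view(-1) is guaranteed to be a view of a contiguous tensor.
print(returns_view(lambda t: t.view(-1)))  # True
# Boolean-mask indexing always copies.
print(returns_view(lambda t: t[~torch.eye(t.shape[0], dtype=torch.bool)]))  # False
```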