# Source code for rindti.losses.snnl

from typing import Union

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch import LongTensor, Tensor


class SoftNearestNeighborLoss(LightningModule):
    """Soft Nearest Neighbor Loss. `[paper] <https://arxiv.org/pdf/1902.01889.pdf>`_

    For every sample ``i`` the loss is::

        -log(eps + sum_{j != i, y_j == y_i} e_ij / (eps + sum_{k != i} e_ik))

    with ``e_ij = exp(-(1 - cos(x_i, x_j)) * temp_frac / temperature)``.

    Args:
        temperature (float, optional): Temperature. Defaults to 1.
        eps (float, optional): Epsilon added for numerical stability. Defaults to 1e-6.
        optim_temperature (bool, optional): Whether to optimise temperature with one
            gradient-descent step per call. Defaults to False.
        grad_step (float, optional): Gradient step for temperature optimisation. Defaults to 0.2.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        eps: float = 1e-6,
        optim_temperature: bool = False,
        grad_step: float = 0.2,
        **kwargs,
    ):
        super().__init__()
        self.temperature = temperature
        self.eps = eps
        self.optim_temperature = optim_temperature
        self.grad_step = grad_step

    def _forward(self, embeds: Tensor, fam_idx: LongTensor, temp_frac: Union[float, Tensor]) -> dict:
        """Calculate the soft nearest neighbor loss for a given temperature denominator.

        Args:
            embeds (Tensor): 2-D embeddings; rows are samples (``F.normalize`` and ``.t()``
                are applied along that layout).
            fam_idx (LongTensor): Label per row, either shape ``(N,)`` or ``(N, 1)``.
            temp_frac (float or Tensor): The effective temperature is ``temperature / temp_frac``.

        Returns:
            dict: ``{"graph_loss": Tensor}`` holding one loss value per row (not reduced).
        """
        embeds = F.normalize(embeds)
        # Pairwise cosine distance between rows.
        sim = 1 - torch.matmul(embeds, embeds.t())
        # Zero the diagonal so a sample is never counted as its own neighbour.
        not_self = 1 - torch.eye(sim.size(0), device=embeds.device)
        expsim = torch.exp(-sim / (self.temperature / temp_frac)) * not_self
        # keepdim=True normalises each row by its OWN sum; a plain (N,) sum would
        # broadcast across columns and divide entry (i, j) by row j's sum.
        f = expsim / (self.eps + expsim.sum(dim=1, keepdim=True))
        # Reshape to a column so the comparison builds an (N, N) same-label mask for
        # both 1-D and (N, 1) label tensors (.t() is a no-op on 1-D tensors).
        labels = fam_idx.reshape(-1, 1)
        fam_mask = (labels == labels.t()).float()
        loss = -torch.log(self.eps + (f * fam_mask).sum(dim=1))
        return dict(graph_loss=loss)

    def forward(self, embeds: Tensor, fam_idx: LongTensor) -> dict:
        """Compute the per-sample soft nearest neighbor loss.

        When ``optim_temperature`` is set, the temperature denominator starts at 1 and is
        updated by one gradient-descent step (``grad_step``) on the mean loss. The loss is
        then recomputed at the updated temperature.

        Args:
            embeds (Tensor): 2-D embeddings, one row per sample.
            fam_idx (LongTensor): Label per row, shape ``(N,)`` or ``(N, 1)``.

        Returns:
            dict: ``{"graph_loss": Tensor}`` with one loss value per row.
        """
        if not self.optim_temperature:
            return self._forward(embeds, fam_idx, 1.0)
        temp_frac = torch.tensor(1.0, device=embeds.device, dtype=torch.float32, requires_grad=True)
        # _forward returns a dict; take the loss tensor before reducing it.
        probe = self._forward(embeds, fam_idx, temp_frac)["graph_loss"].mean()
        (grad,) = torch.autograd.grad(probe, temp_frac)
        # The updated temperature is treated as a constant for the returned loss.
        temp_frac = (temp_frac - self.grad_step * grad).detach()
        return self._forward(embeds, fam_idx, temp_frac)