Source code for rindti.losses.lifted_structure

from typing import Dict

import torch
from pytorch_lightning import LightningModule
from torch import LongTensor, Tensor


class GeneralisedLiftedStructureLoss(LightningModule):
    r"""Generalised lifted structure loss.

    For every anchor ``a`` in the batch the loss is

    .. math::

        \left[\log \sum_{p \in P(a)} e^{D_{a,p} - m_{pos}}
        + \log \sum_{n \in N(a)} e^{m_{neg} - D_{a,n}}\right]_+

    where ``D`` is the pairwise euclidean distance between embeddings,
    ``P(a)`` are the samples sharing ``a``'s family (``a`` itself included)
    and ``N(a)`` are the samples of every other family.

    `[paper] <https://arxiv.org/abs/1511.06452>`_

    Args:
        pos_margin (int, optional): Positive margin. Defaults to 0.
        neg_margin (int, optional): Negative margin. Defaults to 1.
    """

    def __init__(self, pos_margin: int = 0, neg_margin: int = 1, **kwargs) -> None:
        # ``kwargs`` is accepted and ignored, presumably so a shared config
        # dict can be splatted into the constructor - TODO confirm with callers.
        super().__init__()
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin

    def forward(self, embeds: Tensor, fam_idx: LongTensor) -> Dict[str, Tensor]:
        """Compute the per-anchor loss for a batch of embeddings.

        Args:
            embeds (Tensor): Embeddings, one row per sample (assumed ``(N, D)``).
            fam_idx (LongTensor): Family label of every sample. Any layout
                holding ``N`` labels is accepted: ``(N,)``, ``(N, 1)`` or ``(1, N)``.

        Returns:
            Dict[str, Tensor]: ``graph_loss`` - the unreduced loss, one value
            per anchor (shape ``(N,)`` for ``(N, D)`` embeddings).
        """
        dist = torch.cdist(embeds, embeds)
        # Turn the labels into a column so that comparing against the transpose
        # broadcasts to the full (N, N) "same family" matrix. A 1-D vector would
        # otherwise compare element-wise and mark every pair as positive.
        fam_idx = fam_idx.reshape(-1, 1)
        same_fam = fam_idx == fam_idx.t()
        # Excluded pairs must contribute nothing to the log-sum-exp. Zeroing
        # them (multiplying by the mask) would add exp(0) = 1 per excluded pair.
        # The most negative finite value is used instead of -inf so that a
        # column with no negatives (whole batch from one family) stays finite
        # and back-propagates finite (zero) gradients instead of NaN.
        fill = torch.finfo(dist.dtype).min
        pos = (dist - self.pos_margin).masked_fill(~same_fam, fill)
        neg = (self.neg_margin - dist).masked_fill(same_fam, fill)
        # ``dist`` and the mask are symmetric, so reducing over dim 0 equals
        # reducing over each anchor's row.
        pos_loss = torch.logsumexp(pos, dim=0)
        neg_loss = torch.logsumexp(neg, dim=0)
        loss = torch.relu(pos_loss + neg_loss)
        return dict(graph_loss=loss)