rindti.losses

class GeneralisedLiftedStructureLoss(pos_margin: int = 0, neg_margin: int = 1, **kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Generalised lifted structure loss.

[paper]
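For reference, here is one common formulation of the generalised lifted structure loss. It is written under stated assumptions and may not match the implementation exactly. $D$ is assumed to be Euclidean distance. $P(a)$ is the set of embeddings that share the anchor's fam_idx, excluding the anchor itself, and $N(a)$ is the set of embeddings with a different fam_idx. $m_{+}$ is pos_margin and $m_{-}$ is neg_margin. $A$ is the set of anchors that have at least one positive and one negative. Averaging over anchors rather than summing is also an assumption.

```latex
\mathcal{L} = \frac{1}{|A|} \sum_{a \in A} \Big[ \log \sum_{p \in P(a)} e^{D_{a,p} - m_{+}}
            + \log \sum_{n \in N(a)} e^{m_{-} - D_{a,n}} \Big]_{+},
\qquad [x]_{+} = \max(x, 0)
```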

Parameters
  • pos_margin (int, optional) – Margin for positive pairs (embeddings with the same fam_idx). Defaults to 0.

  • neg_margin (int, optional) – Margin for negative pairs (embeddings with different fam_idx). Defaults to 1.
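The following is a minimal pure-PyTorch sketch of how the two margins could enter a generalised lifted structure loss. It is not rindti's code. The function name `generalised_lifted_structure_loss` is hypothetical, as are the following choices: Euclidean distance, treating `fam_idx` as the class label, and averaging over anchors that have both a positive and a negative.

```python
import torch


def generalised_lifted_structure_loss(
    embeds: torch.Tensor,
    fam_idx: torch.Tensor,
    pos_margin: float = 0.0,
    neg_margin: float = 1.0,
) -> torch.Tensor:
    """Hypothetical sketch of a generalised lifted structure loss (not rindti's code)."""
    n = embeds.size(0)
    # Pairwise Euclidean distances. Clamping before sqrt keeps the gradient
    # finite on the zero-distance diagonal.
    diff = embeds.unsqueeze(1) - embeds.unsqueeze(0)
    dist = diff.pow(2).sum(-1).clamp_min(1e-12).sqrt()

    same = fam_idx.unsqueeze(0) == fam_idx.unsqueeze(1)
    eye = torch.eye(n, dtype=torch.bool, device=embeds.device)
    pos_mask = same & ~eye  # same family, anchor itself excluded
    neg_mask = ~same  # different family

    # Keep only anchors with at least one positive and one negative,
    # so every logsumexp row below has a finite entry.
    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    if not valid.any():
        return embeds.sum() * 0.0  # zero loss that stays in the graph
    dist, pos_mask, neg_mask = dist[valid], pos_mask[valid], neg_mask[valid]

    neg_inf = float("-inf")
    # Soft maximum over positive distances, shifted by the positive margin.
    pos_term = torch.logsumexp((dist - pos_margin).masked_fill(~pos_mask, neg_inf), dim=1)
    # Soft maximum over (negative margin - negative distance).
    neg_term = torch.logsumexp((neg_margin - dist).masked_fill(~neg_mask, neg_inf), dim=1)
    # Hinge at zero, then average over the valid anchors.
    return torch.relu(pos_term + neg_term).mean()
```

With the default margins, a batch whose families already lie well beyond `neg_margin` of each other gives zero loss. A batch collapsed onto a single point is penalised.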

forward(embeds: torch.Tensor, fam_idx: torch.LongTensor) → Dict[str, torch.Tensor]

Compute the loss for a batch of embeddings.

Parameters
  • embeds (torch.Tensor) – Embeddings for the batch.

  • fam_idx (torch.LongTensor) – Family index of each embedding, used as its label.

Returns

Dictionary of named loss tensors.

class SoftNearestNeighborLoss(temperature: float = 1.0, eps: float = 1e-06, optim_temperature: bool = False, grad_step: float = 0.2, **kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Soft Nearest Neighbor Loss.

[paper] https://arxiv.org/pdf/1902.01889.pdf
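The linked paper defines the loss for a batch of $b$ embeddings $x_i$ with labels $y_i$ and temperature $T$ as follows. This is the paper's definition. The class may add the eps stability term, so it is not necessarily identical to this formula.

```latex
\mathcal{L} = -\frac{1}{b} \sum_{i=1}^{b} \log
  \frac{\sum_{j \neq i,\; y_j = y_i} \exp\left(-\lVert x_i - x_j \rVert^2 / T\right)}
       {\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2 / T\right)}
```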

Parameters
  • temperature (float, optional) – Temperature T that scales the pairwise distances. Defaults to 1.

  • eps (float, optional) – Small constant for numerical stability. Defaults to 1e-6.

  • optim_temperature (bool, optional) – Whether to optimise temperature. Defaults to False.

  • grad_step (float, optional) – Gradient step for temperature optimisation. Defaults to 0.2.
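Below is a minimal PyTorch sketch of the soft nearest neighbor loss, together with one plausible reading of optim_temperature and grad_step. That reading is a single gradient-descent step on T that lowers the loss. Treat everything here as assumptions: the names `soft_nearest_neighbor_loss` and `temperature_step`, the placement of `eps`, and the hypothetical `min_temperature` floor. How rindti actually updates the temperature is not documented here.

```python
from typing import Union

import torch


def soft_nearest_neighbor_loss(
    embeds: torch.Tensor,
    labels: torch.Tensor,
    temperature: Union[float, torch.Tensor] = 1.0,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Hypothetical sketch of the soft nearest neighbor loss (not rindti's code)."""
    n = embeds.size(0)
    # Squared Euclidean distances, turned into similarities by the temperature.
    sq_dist = (embeds.unsqueeze(1) - embeds.unsqueeze(0)).pow(2).sum(-1)
    sim = torch.exp(-sq_dist / temperature)

    not_self = ~torch.eye(n, dtype=torch.bool, device=embeds.device)
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_self
    numer = (sim * same).sum(dim=1)  # similarity mass on same-label neighbours
    denom = (sim * not_self).sum(dim=1)  # similarity mass on all other points
    # eps keeps the ratio and the log finite (assumed role of `eps`).
    return -torch.log(eps + numer / (eps + denom)).mean()


def temperature_step(
    embeds: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
    grad_step: float = 0.2,
    eps: float = 1e-6,
    min_temperature: float = 1e-3,  # hypothetical floor to keep T positive
) -> float:
    """One gradient-descent step on T that lowers the loss (assumed behaviour)."""
    t = torch.tensor(float(temperature), requires_grad=True)
    # Embeddings are detached: this step only updates the temperature.
    loss = soft_nearest_neighbor_loss(embeds.detach(), labels, t, eps)
    (grad,) = torch.autograd.grad(loss, t)
    return max((t - grad_step * grad).item(), min_temperature)
```

If all embeddings coincide, the loss does not depend on T, so the step leaves the temperature unchanged. The loss then equals log 3 for two classes of two points each.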