rindti.losses
- class GeneralisedLiftedStructureLoss(pos_margin: int = 0, neg_margin: int = 1, **kwargs)
Bases: pytorch_lightning.core.lightning.LightningModule
Generalised lifted structure loss.
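For reference, the generalised lifted structure loss is commonly written as follows (the form given by Hermans et al., 2017, with an added positive margin). This is an assumed formulation and has not been confirmed against the rindti source. Here D is a pairwise distance between embeddings, A is the set of anchors in the batch, P(a) and N(a) are the positives and negatives of anchor a (assumed to be the other embeddings with the same and with a different fam_idx), m_pos and m_neg correspond to pos_margin and neg_margin, and [x]_+ = max(x, 0).

```latex
\mathcal{L} = \frac{1}{|A|} \sum_{a \in A}
  \left[
    \log \sum_{p \in P(a)} e^{D_{ap} - m_{\mathrm{pos}}}
    + \log \sum_{n \in N(a)} e^{m_{\mathrm{neg}} - D_{an}}
  \right]_+
```

The first log-sum-exp acts as a soft maximum over positive distances, and the second as a soft minimum over negative distances. The anchor therefore incurs loss until its farthest positives are pulled in and its closest negatives are pushed beyond the negative margin.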
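- Parameters
pos_margin (int) – margin for positive pairs. Defaults to 0.
neg_margin (int) – margin for negative pairs. Defaults to 1.
**kwargs – additional keyword arguments.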
- forward(embeds: torch.Tensor, fam_idx: torch.LongTensor) → Dict[str, torch.Tensor]
Same as torch.nn.Module.forward().
- Parameters
embeds (torch.Tensor) – batch of embeddings, one row per sample.
fam_idx (torch.LongTensor) – family index of each embedding. Rows with equal indices belong to the same family.
- Returns
Dictionary mapping output names to tensors.
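The following minimal pure-PyTorch sketch shows what a forward pass with these inputs might compute. It is not the rindti implementation. It assumes Euclidean distances, treats embeddings that share a fam_idx value as positives and all others as negatives, clips each anchor's loss at zero, and averages over anchors. The function name `generalised_lifted_structure_loss` and the `"loss"` key in the returned dictionary are hypothetical.

```python
from typing import Dict

import torch


def generalised_lifted_structure_loss(
    embeds: torch.Tensor,
    fam_idx: torch.LongTensor,
    pos_margin: float = 0.0,
    neg_margin: float = 1.0,
) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch of a generalised lifted structure loss (not rindti's code)."""
    n = embeds.size(0)
    # Pairwise Euclidean distances between all embeddings, shape (n, n).
    dist = torch.cdist(embeds, embeds)

    # Assumption: positives share a family index (excluding the anchor itself),
    # negatives belong to a different family.
    same = fam_idx.unsqueeze(0) == fam_idx.unsqueeze(1)
    eye = torch.eye(n, dtype=torch.bool, device=embeds.device)
    pos_mask = same & ~eye
    neg_mask = ~same

    # Masked log-sum-exp: excluded pairs are filled with the most negative
    # finite value, so they contribute (almost) nothing to each row's sum.
    very_neg = torch.finfo(dist.dtype).min
    pos_term = torch.logsumexp((dist - pos_margin).masked_fill(~pos_mask, very_neg), dim=1)
    neg_term = torch.logsumexp((neg_margin - dist).masked_fill(~neg_mask, very_neg), dim=1)

    # Keep only anchors with at least one positive and one negative,
    # clip each anchor's loss at zero, then average.
    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    per_anchor = torch.relu(pos_term + neg_term)[valid]
    # An empty tensor sums to 0, so a batch with no valid anchors yields zero loss.
    loss = per_anchor.mean() if per_anchor.numel() > 0 else per_anchor.sum()
    return {"loss": loss}
```

For example, `generalised_lifted_structure_loss(torch.randn(6, 8), torch.tensor([0, 0, 1, 1, 2, 2]))["loss"]` is a scalar tensor. Masking with a large finite negative value instead of `-inf` avoids NaN results when a row has no positives or no negatives. Those anchors are then dropped by `valid`, although a real implementation may handle them differently.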