# Source code for rindti.models.dti.classification

import torch.nn.functional as F
from torch.functional import Tensor

from ...data import TwoGraphData
from ...layers.encoder import GraphEncoder, PretrainedEncoder, SweetNetEncoder
from ...layers.other import MLP
from ...utils import remove_arg_prefix
from ..base_model import BaseModel

encoders = {"graph": GraphEncoder, "sweetnet": SweetNetEncoder, "pretrained": PretrainedEncoder}


class ClassificationModel(BaseModel):
    """Model for DTI prediction as a classification problem.

    Two encoders (one for the ``prot`` input, one for the ``drug`` input) are
    picked from the module-level ``encoders`` registry. Their embeddings are
    merged into a joint embedding, which an MLP maps to a single output per
    sample. That output is treated as a logit: the loss in ``shared_step`` is
    binary cross-entropy *with logits*, so no sigmoid is applied here.
    """

    def __init__(self, **kwargs):
        """Build the encoders and the prediction head from the config.

        Args:
            **kwargs: Configuration, forwarded unchanged to ``BaseModel``.
                The keys read here are:

                * ``kwargs["model"]["feat_method"]``: how the two embeddings
                  are merged.
                * ``kwargs["model"]["prot"]`` and ``kwargs["model"]["drug"]``:
                  per-encoder config. Each must contain ``"method"`` (a key of
                  ``encoders``) and ``"hidden_dim"``; the whole sub-dict is
                  passed to the encoder constructor as keyword arguments.
                * ``kwargs["model"]["mlp"]``: extra keyword arguments for the
                  ``MLP`` head.

        Raises:
            KeyError: If a required config key is missing or the encoder
                ``"method"`` is not registered in ``encoders``.
        """
        super().__init__(**kwargs)
        model_cfg = kwargs["model"]
        prot_cfg = model_cfg["prot"]
        drug_cfg = model_cfg["drug"]

        # NOTE(review): presumably this sets ``self.embed_dim`` and
        # ``self.merge_features`` (both are used below but never assigned in
        # this class) -- confirm against BaseModel.
        self._determine_feat_method(
            model_cfg["feat_method"],
            prot_cfg["hidden_dim"],
            drug_cfg["hidden_dim"],
        )

        self.prot_encoder = encoders[prot_cfg["method"]](**prot_cfg)
        self.drug_encoder = encoders[drug_cfg["method"]](**drug_cfg)
        # Single output unit: one logit per (prot, drug) pair.
        self.mlp = MLP(input_dim=self.embed_dim, out_dim=1, **model_cfg["mlp"])
        self._set_class_metrics()

    def forward(self, prot: dict, drug: dict) -> dict:
        """Encode both inputs, merge the embeddings and predict.

        Args:
            prot: Input passed to the ``prot`` encoder.
            drug: Input passed to the ``drug`` encoder.

        Returns:
            dict: With the keys

            * ``"pred"``: output of the MLP on the joint embedding,
            * ``"prot_embed"``: output of the ``prot`` encoder,
            * ``"drug_embed"``: output of the ``drug`` encoder,
            * ``"joint_embed"``: the merged embedding fed to the MLP.
        """
        prot_embed = self.prot_encoder(prot)
        drug_embed = self.drug_encoder(drug)
        # Argument order matters for non-symmetric merges: drug first.
        joint_embedding = self.merge_features(drug_embed, prot_embed)
        return {
            "pred": self.mlp(joint_embedding),
            "prot_embed": prot_embed,
            "drug_embed": drug_embed,
            "joint_embed": joint_embedding,
        }

    def shared_step(self, data: TwoGraphData) -> dict:
        """Step that is the same for train, validation and test.

        Args:
            data: Batch holding ``prot_``- and ``drug_``-prefixed entries plus
                a ``label`` attribute.

        Returns:
            dict: dict with different metrics - losses, accuracies etc.
            Has to contain 'loss'. Here it holds ``"loss"`` (BCE-with-logits),
            ``"preds"`` (detached raw predictions) and ``"labels"`` (detached
            labels with an extra trailing dimension).
        """
        prot = remove_arg_prefix("prot_", data)
        drug = remove_arg_prefix("drug_", data)
        fwd_dict = self.forward(prot, drug)
        # unsqueeze(1) adds a trailing dimension so labels line up with the
        # single-output predictions of the MLP head.
        labels = data.label.unsqueeze(1)
        bce_loss = F.binary_cross_entropy_with_logits(fwd_dict["pred"], labels.float())
        return {
            "loss": bce_loss,
            "preds": fwd_dict["pred"].detach(),
            "labels": labels.detach(),
        }