# Source code for rindti.models.dti.baseline.prot_drug_max_likelihood

import numpy as np
import pandas as pd

from .base_baseline import BaseBaseline


class ProtDrugMax(BaseBaseline):
    """Take the average of the drug and prot labels.

    After :meth:`fit`, a prediction for a (protein, drug) pair is built from
    the per-protein and/or per-drug mean of ``Y`` seen in the training data.

    Args:
        which: Which statistics to use: ``"both"`` (average of the protein
            and the drug value), ``"prot"`` or ``"drug"``. The value is only
            validated when :meth:`predict_pair` is called.
        prob: Forwarded to ``BaseBaseline.__init__``. This class reads
            ``self.prob`` to decide whether to sample a value instead of
            returning the mean (presumably set by the base class from this
            argument -- confirm against ``BaseBaseline``).
    """

    def __init__(self, which: str = "both", prob: bool = False):
        super().__init__(prob)
        self.which = which

    def fit(self, train: pd.DataFrame):
        """Fit the model to the training dataframe.

        Has to have 'Drug_ID', 'Target_ID' and 'Y' columns. The dataframe is
        only read; it is not modified.

        Args:
            train: Training interactions.
        """
        # Only the mean and the sample standard deviation of Y are used later,
        # so aggregate the Y column directly rather than adding a helper
        # column to the caller's dataframe.
        stats = ["mean", "std"]
        self.protgroup = train.groupby("Target_ID")["Y"].agg(stats)
        self.druggroup = train.groupby("Drug_ID")["Y"].agg(stats)

    def _get_val(self, group: pd.DataFrame, id: str) -> float:
        """Return the value predicted from one group table for one ID.

        Args:
            group: Table indexed by ID with ``"mean"`` and ``"std"`` columns,
                as built by :meth:`fit`.
            id: Identifier to look up in ``group.index``.

        Returns:
            ``0.5`` when ``id`` is not in the table. Otherwise the group mean,
            or, when ``self.prob`` is truthy, a draw from a normal distribution
            with the group's mean and standard deviation.
        """
        if id not in group.index:
            # Unseen ID: fall back to a fixed neutral value.
            return 0.5

        mean = group.at[id, "mean"]
        if not self.prob:
            return mean

        std = group.at[id, "std"]
        if np.isnan(std):
            # The sample std is NaN for an ID with a single training row;
            # sampling with it would return NaN, so use the mean instead.
            return mean
        return np.random.normal(loc=mean, scale=std)

    def predict_pair(self, prot_id: str, drug_id: str) -> float:
        """Predict the outcome for a pair of a protein and a drug.

        Args:
            prot_id: Protein identifier, looked up among the 'Target_ID' groups.
            drug_id: Drug identifier, looked up among the 'Drug_ID' groups.

        Returns:
            The protein value, the drug value, or their average, depending on
            ``self.which``.

        Raises:
            ValueError: If ``self.which`` is not "both", "prot" or "drug".
        """
        if self.which == "both":
            prot_val = self._get_val(self.protgroup, prot_id)
            drug_val = self._get_val(self.druggroup, drug_id)
            return (prot_val + drug_val) / 2
        if self.which == "prot":
            return self._get_val(self.protgroup, prot_id)
        if self.which == "drug":
            return self._get_val(self.druggroup, drug_id)
        raise ValueError(f"Unknown value for which: {self.which}")