rindti.models
Base Model
The base model defines methods that are shared by all models, such as logging and saving.
- class BaseModel(**kwargs)
Bases: pytorch_lightning.core.lightning.LightningModule
Base model, defines a lot of helper functions.
- training_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict
What to do during a training step.
- validation_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict
What to do during a validation step. Also logs the values for various callbacks.
- test_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict
What to do during a test step. Also logs the values for various callbacks.
- training_epoch_end(outputs: dict)
What to do at the end of a training epoch. Logs everything.
DTI models
Drug-target interaction (DTI) prediction models. They calculate embeddings for drugs and proteins, then use an MLP to predict the final result.
- class ClassificationModel(**kwargs)
Bases: rindti.models.base_model.BaseModel
Model for DTI prediction as a classification problem.
- forward(prot: dict, drug: dict) -> torch.Tensor
- class RegressionModel(**kwargs)
Bases: rindti.models.dti.classification.ClassificationModel
Model for DTI prediction as a regression problem.
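The "embeddings, then MLP" idea described for the DTI models can be illustrated with a small framework-free sketch. This is not the rindti implementation: the function name mlp_predict, the concatenation of the two embeddings, the single hidden layer and the sigmoid output are all assumptions made for illustration, and the embeddings are taken as given rather than computed from graphs.

```python
import math


def mlp_predict(prot_emb, drug_emb, w_hidden, b_hidden, w_out, b_out):
    """Hypothetical helper: score a protein/drug pair from precomputed embeddings."""
    # Assumption: the two embeddings are combined by concatenation.
    x = list(prot_emb) + list(drug_emb)
    # One hidden layer with ReLU activation.
    hidden = [
        max(0.0, sum(w * v for w, v in zip(row, x)) + b)
        for row, b in zip(w_hidden, b_hidden)
    ]
    # Single output unit squashed to (0, 1), as in a binary classification setting.
    logit = sum(w * h for w, h in zip(w_out, hidden)) + b_out
    return 1.0 / (1.0 + math.exp(-logit))


# Made-up embeddings and weights, purely for demonstration.
prot_emb = [0.2, -0.1]
drug_emb = [0.5, 0.3]
w_hidden = [[0.1, 0.2, 0.3, 0.4], [-0.4, 0.3, -0.2, 0.1]]
b_hidden = [0.0, 0.1]
w_out = [0.7, -0.5]
b_out = 0.05

score = mlp_predict(prot_emb, drug_emb, w_hidden, b_hidden, w_out, b_out)

# With all weights and biases set to zero the output is exactly 0.5.
neutral = mlp_predict(prot_emb, drug_emb, [[0.0] * 4, [0.0] * 4], [0.0, 0.0], [0.0, 0.0], 0.0)
```

In the actual models the embeddings come from the drug and protein inputs passed to forward; here they are plain lists so that the example stays dependency-free.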
Baseline models
These models are used to predict the baseline values for datasets. They do not have access to the actual features and operate solely on the labels of drugs and proteins.
- class BaseBaseline(prob: bool = False, **kwargs)
Bases: object
Parent of all baseline models.
- fit(train: pandas.core.frame.DataFrame)
Fit the model to the training dataframe. The dataframe must have 'Drug_ID', 'Target_ID' and 'Y' columns.
- predict_pair(prot_id: str, drug_id: str) -> float
Predict the outcome for a protein-drug pair.
- test_metrics(test: pandas.core.frame.DataFrame) -> dict
Calculate the metrics for the test dataframe.
- predict(test: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame
Apply prediction to the whole test dataframe.
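The way these four methods fit together can be sketched as follows. This is a hypothetical stand-in, not the rindti code: the class names BaselineSketch and AlwaysOne, the 'pred' key and the choice of accuracy as the only metric are assumptions, and rows are plain dictionaries rather than a pandas dataframe to keep the example dependency-free.

```python
class BaselineSketch:
    """Hypothetical base class mirroring the documented method names."""

    def fit(self, train):
        # Subclasses learn whatever statistics they need from the training rows.
        raise NotImplementedError

    def predict_pair(self, prot_id, drug_id):
        # Subclasses return a prediction for one protein/drug pair.
        raise NotImplementedError

    def predict(self, test):
        # Apply predict_pair to every row; return copies with an added 'pred' entry.
        return [
            dict(row, pred=self.predict_pair(row["Target_ID"], row["Drug_ID"]))
            for row in test
        ]

    def test_metrics(self, test):
        # Assumption: accuracy is used here; the real metrics are not documented above.
        preds = self.predict(test)
        correct = sum(1 for row in preds if round(row["pred"]) == row["Y"])
        return {"accuracy": correct / len(preds)}


class AlwaysOne(BaselineSketch):
    """Trivial subclass used only to exercise the base class."""

    def fit(self, train):
        return self

    def predict_pair(self, prot_id, drug_id):
        return 1.0


rows = [
    {"Drug_ID": "D1", "Target_ID": "P1", "Y": 1},
    {"Drug_ID": "D2", "Target_ID": "P1", "Y": 1},
    {"Drug_ID": "D1", "Target_ID": "P2", "Y": 1},
    {"Drug_ID": "D2", "Target_ID": "P2", "Y": 0},
]
metrics = AlwaysOne().fit(rows).test_metrics(rows)
```

Only fit and predict_pair differ between baselines in this sketch; predict and test_metrics are built on top of them.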
- class Max(prob: bool = False, **kwargs)
Bases: rindti.models.dti.baseline.base_baseline.BaseBaseline
Take the most popular label.
- fit(train: pandas.core.frame.DataFrame)
Fit the model to the training dataframe. The dataframe must have 'Drug_ID', 'Target_ID' and 'Y' columns.
- predict_pair(prot_id: str, drug_id: str) -> float
Predict the outcome for a protein-drug pair.
- predict(test: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame
Apply prediction to the whole test dataframe.
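A "most popular label" baseline can be sketched in a few lines. The class MaxSketch below is a hypothetical illustration, not the rindti Max class: it ignores the prob argument, whose behaviour is not described above, and takes rows as plain dictionaries instead of a pandas dataframe.

```python
from collections import Counter


class MaxSketch:
    """Hypothetical majority-label baseline: ignores the IDs entirely."""

    def fit(self, train):
        # Count how often each label occurs and keep the most frequent one.
        counts = Counter(row["Y"] for row in train)
        self.label = counts.most_common(1)[0][0]
        return self

    def predict_pair(self, prot_id, drug_id):
        # Every pair receives the same prediction, seen or unseen.
        return float(self.label)


train_rows = [
    {"Drug_ID": "D1", "Target_ID": "P1", "Y": 0},
    {"Drug_ID": "D2", "Target_ID": "P1", "Y": 0},
    {"Drug_ID": "D1", "Target_ID": "P2", "Y": 1},
]
model = MaxSketch().fit(train_rows)
prediction = model.predict_pair("P9", "D9")
```

Because the prediction never depends on the pair, this kind of baseline shows how much of a model's score is explained by label imbalance alone.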
- class ProtDrugMax(which: str = 'both', prob: bool = False)
Bases: rindti.models.dti.baseline.base_baseline.BaseBaseline
Take the average of the drug and protein labels.
- fit(train: pandas.core.frame.DataFrame)
Fit the model to the training dataframe. The dataframe must have 'Drug_ID', 'Target_ID' and 'Y' columns.
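One way to read "average of the drug and protein labels" is sketched below. This is an interpretation, not the rindti implementation: the class name ProtDrugMaxSketch, the 'prot' and 'drug' values for which (only the default 'both' appears in the signature above), the use of the per-ID majority label, and the fallback to the overall majority label for unseen IDs are all assumptions. The prob argument is left out.

```python
from collections import Counter, defaultdict


class ProtDrugMaxSketch:
    """Hypothetical baseline using per-protein and per-drug majority labels."""

    def __init__(self, which="both"):
        # Assumption: 'prot' and 'drug' select a single side; only 'both' is documented.
        if which not in ("both", "prot", "drug"):
            raise ValueError(f"unknown value for which: {which!r}")
        self.which = which

    def fit(self, train):
        prot_counts = defaultdict(Counter)
        drug_counts = defaultdict(Counter)
        overall = Counter()
        for row in train:
            prot_counts[row["Target_ID"]][row["Y"]] += 1
            drug_counts[row["Drug_ID"]][row["Y"]] += 1
            overall[row["Y"]] += 1
        # Majority label per protein, per drug, and overall (fallback for unseen IDs).
        self.prot_label = {k: c.most_common(1)[0][0] for k, c in prot_counts.items()}
        self.drug_label = {k: c.most_common(1)[0][0] for k, c in drug_counts.items()}
        self.default = overall.most_common(1)[0][0]
        return self

    def predict_pair(self, prot_id, drug_id):
        prot = self.prot_label.get(prot_id, self.default)
        drug = self.drug_label.get(drug_id, self.default)
        if self.which == "prot":
            return float(prot)
        if self.which == "drug":
            return float(drug)
        # 'both': average the two labels.
        return (prot + drug) / 2


train_rows = [
    {"Drug_ID": "D1", "Target_ID": "P1", "Y": 1},
    {"Drug_ID": "D1", "Target_ID": "P2", "Y": 1},
    {"Drug_ID": "D1", "Target_ID": "P3", "Y": 0},
    {"Drug_ID": "D2", "Target_ID": "P3", "Y": 0},
    {"Drug_ID": "D3", "Target_ID": "P3", "Y": 0},
]
both = ProtDrugMaxSketch("both").fit(train_rows).predict_pair("P3", "D1")
prot_only = ProtDrugMaxSketch("prot").fit(train_rows).predict_pair("P3", "D1")
drug_only = ProtDrugMaxSketch("drug").fit(train_rows).predict_pair("P3", "D1")
```

In this toy data protein P3 is mostly labelled 0 and drug D1 mostly 1, so the three settings give different answers for the same pair.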