rindti.models

Base Model

The base model defines common methods that are identical for all models, such as logging and saving.

class BaseModel(**kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Base model that defines helper functions shared by all models.

training_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict

Run a single training step on a batch and return a dict of outputs.

validation_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict

Run a single validation step on a batch. Also logs the values used by various callbacks.

test_step(data: rindti.data.data.TwoGraphData, data_idx: int) -> dict

Run a single test step on a batch. Also logs the values used by various callbacks.

log_histograms()

Logs the histograms of all the available parameters.

log_all(metrics: dict, hparams: bool = False)

Log all metrics.

training_epoch_end(outputs: dict)

Hook called at the end of a training epoch. Logs all collected outputs.

validation_epoch_end(outputs: dict)

Hook called at the end of a validation epoch. Logs all collected outputs.

test_epoch_end(outputs: dict)

Hook called at the end of a test epoch. Logs all collected outputs.
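In PyTorch Lightning 1.x, each *_epoch_end hook receives the outputs collected from every matching *_step call of the epoch (a list of the dicts those steps returned). The stdlib-only sketch below shows one way such per-batch dicts could be averaged into epoch-level metrics before logging. The helper name aggregate_epoch_outputs and the metric keys are hypothetical; this is an illustration of the pattern, not the actual BaseModel implementation.

```python
from statistics import mean
from typing import Dict, List


def aggregate_epoch_outputs(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric over all step outputs of one epoch.

    Hypothetical helper: mirrors how a *_epoch_end hook might reduce the
    per-batch dicts returned by *_step before logging them.
    """
    if not outputs:
        return {}
    keys = outputs[0].keys()
    # Assumes every step returned the same metric keys.
    return {key: mean(step[key] for step in outputs) for key in keys}


# Per-batch dicts as a validation_step might return them (made-up values).
step_outputs = [
    {"loss": 0.9, "acc": 0.50},
    {"loss": 0.7, "acc": 0.75},
]
epoch_metrics = aggregate_epoch_outputs(step_outputs)
print(epoch_metrics)
```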

configure_optimizers() -> Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler._LRScheduler]

Configure the optimizer and/or learning-rate scheduler.

DTI models

Drug-target interaction (DTI) prediction models. They calculate embeddings for drugs and proteins, then use an MLP to predict the final result.
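A minimal, dependency-free sketch of that pipeline follows. It assumes the protein and drug embeddings are simply concatenated before the MLP, which this page does not state. The embeddings and weights are made-up placeholders, and the names linear, relu and mlp_predict are hypothetical; the real models are PyTorch modules.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]
Layer = Tuple[Matrix, Vector]  # (weights, bias)


def linear(x: Vector, layer: Layer) -> Vector:
    """Dense layer: one output per weight row, plus bias."""
    weights, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def mlp_predict(prot_emb: Vector, drug_emb: Vector, hidden: Layer, out: Layer) -> float:
    """Hypothetical DTI head: concatenate embeddings, one hidden ReLU layer, sigmoid."""
    x = prot_emb + drug_emb  # assumed combination step: list concatenation
    h = relu(linear(x, hidden))
    logit = linear(h, out)[0]
    return 1.0 / (1.0 + math.exp(-logit))  # probability of interaction


prot_emb = [0.5, -1.0]  # placeholder protein embedding
drug_emb = [1.0, 2.0]   # placeholder drug embedding
hidden = ([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]], [0.0, 0.0])  # 4 -> 2
out = ([[1.0, -1.0]], [0.0])  # 2 -> 1
print(mlp_predict(prot_emb, drug_emb, hidden, out))
```

The sigmoid suits the classification setting; a regression head would typically return the raw logit instead.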

class ClassificationModel(**kwargs)

Bases: rindti.models.base_model.BaseModel

Model for DTI prediction as a classification problem.

forward(prot: dict, drug: dict) -> torch.Tensor

class RegressionModel(**kwargs)

Bases: rindti.models.dti.classification.ClassificationModel

Model for DTI prediction as a regression problem.

Baseline models

These models predict baseline values for datasets. They have no access to the actual features and operate solely on the labels associated with each drug and protein.

class BaseBaseline(prob: bool = False, **kwargs)

Bases: object

Parent of all baseline models.

fit(train: pandas.core.frame.DataFrame)

Fit the model to the training dataframe, which must have ‘Drug_ID’, ‘Target_ID’ and ‘Y’ columns.

predict_pair(prot_id: str, drug_id: str) -> float

Predict the outcome for a pair of a protein and a drug.

test_metrics(test: pandas.core.frame.DataFrame) -> dict

Calculate the metrics for the test dataframe.

predict(test: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame

Apply prediction to the whole test dataframe.

assess_dataset(filename: str, train_frac: float = 0.8, n_runs: int = 10)

Assess the performance of the model on a dataset.
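The interface above (fit on a training table, predict per pair, score a test table, repeat over several random splits) can be sketched without pandas. Rows here are plain dicts with the same ‘Drug_ID’, ‘Target_ID’ and ‘Y’ keys. The class SketchBaseline, its mean-label strategy, the accuracy metric, the "pred" key and the split logic in assess_rows are all illustrative assumptions; the real assess_dataset reads a file and its exact procedure is not documented here.

```python
import random
from statistics import mean
from typing import Dict, List

Row = Dict[str, object]  # plain-dict stand-in for a DataFrame row


class SketchBaseline:
    """Hypothetical stand-in for BaseBaseline that predicts the mean training label."""

    def fit(self, train: List[Row]) -> None:
        # Only the labels are used; IDs and features are ignored.
        self.value = mean(float(row["Y"]) for row in train)

    def predict_pair(self, prot_id: str, drug_id: str) -> float:
        return self.value

    def predict(self, test: List[Row]) -> List[Row]:
        # Copy each row and attach the prediction under an assumed "pred" key.
        return [dict(row, pred=self.predict_pair(row["Target_ID"], row["Drug_ID"])) for row in test]

    def test_metrics(self, test: List[Row]) -> Dict[str, float]:
        # Assumed metric: accuracy after thresholding predictions at 0.5.
        preds = self.predict(test)
        acc = mean(float((row["pred"] >= 0.5) == (row["Y"] >= 0.5)) for row in preds)
        return {"acc": acc}

    def assess_rows(self, rows: List[Row], train_frac: float = 0.8, n_runs: int = 10, seed: int = 0) -> float:
        """Mean test accuracy over n_runs random train/test splits."""
        rng = random.Random(seed)
        scores = []
        for _ in range(n_runs):
            shuffled = rows[:]
            rng.shuffle(shuffled)
            cut = int(len(shuffled) * train_frac)
            self.fit(shuffled[:cut])
            scores.append(self.test_metrics(shuffled[cut:])["acc"])
        return mean(scores)


rows = [{"Drug_ID": f"d{i}", "Target_ID": f"p{i % 3}", "Y": int(i < 7)} for i in range(10)]
print(SketchBaseline().assess_rows(rows, train_frac=0.8, n_runs=5))
```

Repeating the split n_runs times and averaging reduces the variance a single lucky or unlucky split would introduce.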

class Max(prob: bool = False, **kwargs)

Bases: rindti.models.dti.baseline.base_baseline.BaseBaseline

Predict the most common label in the training data.

fit(train: pandas.core.frame.DataFrame)

Fit the model to the training dataframe, which must have ‘Drug_ID’, ‘Target_ID’ and ‘Y’ columns.

predict_pair(prot_id: str, drug_id: str) -> float

Predict the outcome for a pair of a protein and a drug.

predict(test: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame

Apply prediction to the whole test dataframe.
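A minimal sketch of the most-common-label strategy, using collections.Counter on plain dict rows instead of a DataFrame. MaxSketch is a hypothetical name, and the prob flag is not modeled because its behavior is not documented on this page.

```python
from collections import Counter
from typing import Dict, List


class MaxSketch:
    """Hypothetical re-creation of the Max strategy: always predict the most common training label."""

    def fit(self, train: List[Dict[str, object]]) -> None:
        counts = Counter(row["Y"] for row in train)
        # most_common(1) returns [(label, count)] for the top label.
        self.label = counts.most_common(1)[0][0]

    def predict_pair(self, prot_id: str, drug_id: str) -> float:
        return float(self.label)  # the protein and drug IDs are ignored


train = [
    {"Drug_ID": "d1", "Target_ID": "p1", "Y": 1},
    {"Drug_ID": "d2", "Target_ID": "p1", "Y": 0},
    {"Drug_ID": "d3", "Target_ID": "p2", "Y": 1},
]
model = MaxSketch()
model.fit(train)
print(model.predict_pair("p9", "d9"))
```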

class ProtDrugMax(which: str = 'both', prob: bool = False)

Bases: rindti.models.dti.baseline.base_baseline.BaseBaseline

Take the average of the drug and protein labels.

fit(train: pandas.core.frame.DataFrame)

Fit the model to the training dataframe, which must have ‘Drug_ID’, ‘Target_ID’ and ‘Y’ columns.

predict_pair(prot_id: str, drug_id: str) -> float

Predict the outcome for a pair of a protein and a drug.
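One plausible reading of "average of the drug and protein labels", sketched with the standard library: compute the mean label per protein and per drug on the training rows, then average the two for a given pair. ProtDrugSketch is a hypothetical name. The accepted values of which other than 'both' (assumed here to be 'prot' and 'drug') and the fallback to the global mean label for unseen IDs are assumptions, not documented behavior.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List


class ProtDrugSketch:
    """Hypothetical sketch of ProtDrugMax-style prediction from per-entity label means."""

    def __init__(self, which: str = "both"):
        if which not in ("both", "prot", "drug"):  # assumed option set
            raise ValueError(f"unknown which={which!r}")
        self.which = which

    def fit(self, train: List[Dict[str, object]]) -> None:
        prot_labels, drug_labels = defaultdict(list), defaultdict(list)
        for row in train:
            prot_labels[row["Target_ID"]].append(float(row["Y"]))
            drug_labels[row["Drug_ID"]].append(float(row["Y"]))
        self.prot_mean = {k: mean(v) for k, v in prot_labels.items()}
        self.drug_mean = {k: mean(v) for k, v in drug_labels.items()}
        self.global_mean = mean(float(row["Y"]) for row in train)  # assumed fallback

    def predict_pair(self, prot_id: str, drug_id: str) -> float:
        prot = self.prot_mean.get(prot_id, self.global_mean)
        drug = self.drug_mean.get(drug_id, self.global_mean)
        if self.which == "prot":
            return prot
        if self.which == "drug":
            return drug
        return (prot + drug) / 2  # 'both': average the two per-entity means


train = [
    {"Drug_ID": "d1", "Target_ID": "p1", "Y": 1},
    {"Drug_ID": "d1", "Target_ID": "p2", "Y": 0},
    {"Drug_ID": "d2", "Target_ID": "p1", "Y": 1},
    {"Drug_ID": "d2", "Target_ID": "p2", "Y": 1},
]
model = ProtDrugSketch()
model.fit(train)
print(model.predict_pair("p2", "d1"))
```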