Data

TwoGraphData

A subclass of torch_geometric.data.Data, that handles an entry of two graphs.

The two graph are indicated by certain prefix, thus x and edge_index become drug_x and drug_edge_index.

This has an effect on batching during training and prediction, since the prot_edge_index entry gets modified according to the standard processing rules of torch_geometric

Otherwise it is just a dictionary.

from rindti.data import TwoGraphData
import torch
num_prot_nodes, num_drug_nodes = 100, 30
num_prot_edges, num_drug_edges = 100, 30
prot_x = torch.rand(num_prot_nodes, 32)
drug_x = torch.rand(num_drug_nodes, 32)
prot_edge_index = torch.randint(0, num_prot_nodes, (2, num_prot_edges))
drug_edge_index = torch.randint(0, num_drug_nodes, (2, num_drug_edges))
tgd = TwoGraphData(prot_x=prot_x, drug_x=drug_x, prot_edge_index=prot_edge_index, drug_edge_index=drug_edge_index)
print(tgd)
>>> TwoGraphData(prot_x=[100, 32], drug_x=[30, 32], prot_edge_index=[2, 100], drug_edge_index=[2, 30])

Then such objects can be given directly to the dataloader with

from torch_geometric.loader import DataLoader
dl = DataLoader([tgd] * 10, batch_size=5, num_workers=1) # just take 10 times the same graph for simplicity
batch = next(iter(dl))
print(batch)
>>> TwoGraphDataBatch(prot_x=[500, 32], drug_x=[150, 32], prot_edge_index=[2, 500], drug_edge_index=[2, 150])

Datasets

Custom datasets are based on torch_geometric Datasets

They are designed to take in the results of the snakemake workflows, and create a quick-to-load pytorch objects.

DTI datasets

Since the splits in DTI predictions are often non-random (scaffold split/cold target split), for each DTI pair a string indicating the split is provided.

Thus to create DTI datasets one needs to specialize the split:

from rindti.data import DTIDataset
pickle_file = "filename.pkl"
train = DTIDataset(pickle_file, split="train")
val = DTIDataset(pickle_file, split="val")
test = DTIDataset(pickle_file, split="test")

Pretraining datasets

The datasets for pretraining are also obtained from the snakemake workflow, however, the splitting is done internally.

pickle_file = "filename.pkl"
dataset = PreTrainDataset(pickle_file)

DataModules

Datamodules are based on pytorch_lightning DataModules and aim to put all data-related functionality (dataloaders, splitting, sampling) into a single object. Can be invoked simply with:

from rindti.data import DTIDataModule, PreTrainDataModule
dti_pickle_file = "dti.pkl"
dti_dm = DTIDataModule(dti_pickle_file)
pretrain_pickle_file = "pretrain.pkl"
pretrain_dm = PreTrainDataModule(pretrain_pickle_file)