Source code for rindti.data.datamodules

from pytorch_lightning import LightningDataModule
from torch.utils.data.sampler import Sampler
from torch_geometric.loader import DataLoader

from ..utils import split_random
from .datasets import DTIDataset, PreTrainDataset


class BaseDataModule(LightningDataModule):
    """Base data module, contains all the datasets for train, val and test."""

    def __init__(
        self, filename: str, exp_name: str, batch_size: int = 128, num_workers: int = 1, shuffle: bool = True
    ):
        super().__init__()
        self.filename = filename
        self.exp_name = exp_name
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

    def update_config(self, config: dict) -> None:
        raise NotImplementedError

    def train_dataloader(self):
        return DataLoader(self.train, **self._dl_kwargs(True))

    def val_dataloader(self):
        return DataLoader(self.val, **self._dl_kwargs(False))

    def test_dataloader(self):
        return DataLoader(self.test, **self._dl_kwargs(False))

    def predict_dataloader(self):
        return DataLoader(self.test, **self._dl_kwargs(False))


[docs]class DTIDataModule(BaseDataModule): """Data module for the DTI dataset."""
[docs] def setup(self, stage: str = None): """Load the individual datasets""" self.train = DTIDataset(self.filename, self.exp_name, split="train").shuffle() self.val = DTIDataset(self.filename, self.exp_name, split="val").shuffle() self.test = DTIDataset(self.filename, self.exp_name, split="test").shuffle() self.config = self.train.config
def _dl_kwargs(self, shuffle: bool = False): return dict( batch_size=self.batch_size, shuffle=self.shuffle if shuffle else False, num_workers=self.num_workers, follow_batch=["prot_x", "drug_x"], )
[docs] def update_config(self, config: dict) -> None: """Update the main config with the config of the dataset.""" print(self.config) for i in ["prot", "drug"]: config["model"][i]["data"] = self.config["snakemake"]["data"][i]
[docs]class PreTrainDataModule(BaseDataModule): """DataModule for pretraining on prots.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
[docs] def setup(self, stage: str = None): """Load the individual datasets.""" ds = PreTrainDataset(self.filename) self.train, self.val, self.test = split_random(ds, [0.7, 0.2, 0.1]) self.config = ds.config
def _dl_kwargs(self, shuffle: bool = False): return dict( batch_size=self.batch_size, shuffle=self.shuffle if shuffle else False, num_workers=self.num_workers, )
[docs] def update_config(self, config: dict) -> None: """Update the main config with the config of the dataset.""" config["model"]["encoder"]["data"] = self.config["data"] config["model"]["num_classes"] = self.config["data"]["num_classes"]