Source code for rindti.data.datasets

import os
import pickle
from typing import Callable, Iterable

import pandas as pd
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset

from .data import TwoGraphData


class DTIDataset(InMemoryDataset):
    """Dataset class for prots and drugs.

    Args:
        filename (str): Pickle file that stores the data
        exp_name (str): Experiment name, used as a sub-directory of ``data/`` for the processed files
        split (str, optional): Split type ('train', 'val', 'test'). Defaults to "train".
        transform (Callable, optional): transformer to apply on each access. Defaults to None.
        pre_transform (Callable, optional): pre-transformer to apply once before. Defaults to None.
        pre_filter (Callable, optional): predicate applied once before saving,
            entries for which it returns a falsy value are dropped. Defaults to None.

    Raises:
        KeyError: If ``split`` is not one of the keys of ``splits``.
    """

    # Maps the split name to the position of its file in ``processed_file_names``.
    splits = {"train": 0, "val": 1, "test": 2}

    def __init__(
        self,
        filename: str,
        exp_name: str,
        split: str = "train",
        transform: Callable = None,
        pre_transform: Callable = None,
        pre_filter: Callable = None,
    ):
        # Fail before the base class triggers processing, not after it.
        if split not in self.splits:
            raise KeyError(f"Unknown split {split!r}, expected one of {list(self.splits)}")
        root = self._set_filenames(filename, exp_name)
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices, self.config = torch.load(self.processed_paths[self.splits[split]])

    def _set_filenames(self, filename: str, exp_name: str) -> str:
        """Remember the pickle path and return the root directory of the dataset.

        The root is ``data/<exp_name>/<basename of filename without extension>``.
        """
        self.filename = filename
        stem = os.path.splitext(os.path.basename(filename))[0]
        return os.path.join("data", exp_name, stem)

    def process_(self, data_list: list, split: str):
        """Filter, pre-transform, collate and save the entries of one split.

        Args:
            data_list (list): Entries that belong to the split
            split (str): Name of the split, selects the output file
        """
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        data, slices = self.collate(data_list)
        # ``self.config`` is set by ``process`` before this method is called.
        torch.save((data, slices, self.config), self.processed_paths[self.splits[split]])

    def _get_datum(self, all_data: dict, id: str, which: str, **kwargs) -> dict:
        """Get either prot or drug data.

        Args:
            all_data (dict): Loaded pickle, ``all_data[which]`` is indexed with ``.loc[id, column]``
            id (str): Row label of the requested entry
            which (str): Table to read from, "prots" or "drugs"
            **kwargs: Has to contain ``snakemake`` when ``which`` is "drugs"

        Returns:
            dict: Entries of the stored graph plus ``id`` (and ``IUPAC`` if configured),
                every key prefixed with ``which`` without its trailing "s" and an underscore.
        """
        table = all_data[which]
        graph = table.loc[id, "data"]
        prefix = which.rstrip("s") + "_"
        # Build a new dict instead of writing into the object stored in the table,
        # so the loaded data is left untouched.
        datum = {prefix + k: v for k, v in graph.items()}
        datum[prefix + "id"] = id
        if (
            which == "drugs"
            and "drugs" in kwargs["snakemake"]
            and kwargs["snakemake"]["drugs"]["node_feats"] == "IUPAC"
        ):
            datum[prefix + "IUPAC"] = table.loc[id, "IUPAC"]
        return datum

    @property
    def processed_file_names(self) -> Iterable[str]:
        """Files that are created."""
        return [k + ".pt" for k in self.splits]

    def process(self):
        """If the dataset was not seen before, process everything."""
        # NOTE(review): pickle.load can execute arbitrary code, only open trusted files.
        with open(self.filename, "rb") as file:
            all_data = pickle.load(file)
        self.config = {"snakemake": all_data["config"]}
        for split in self.splits:
            data_list = []
            for i in all_data["data"]:
                if i["split"] != split:
                    continue
                data = self._get_datum(all_data, i["prot_id"], "prots", **self.config)
                data.update(self._get_datum(all_data, i["drug_id"], "drugs", **self.config))
                data["label"] = i["label"]
                two_graph_data = TwoGraphData(**data)
                two_graph_data.num_nodes = 1  # suppresses the warning
                data_list.append(two_graph_data)
            # NOTE(review): a split without entries writes no file here, so requesting
            # that split afterwards presumably fails in torch.load - confirm.
            if data_list:
                self.process_(data_list, split)
class PreTrainDataset(InMemoryDataset):
    """Dataset class for pre-training.

    Args:
        filename (str): Pickle file that stores the data
        transform (Callable, optional): transformer to apply on each access. Defaults to None.
        pre_transform (Callable, optional): pre-transformer to apply once before. Defaults to None.
        pre_filter (Callable, optional): predicate applied once before saving,
            entries for which it returns a falsy value are dropped. Defaults to None.
    """

    def __init__(
        self,
        filename: str,
        transform: Callable = None,
        pre_transform: Callable = None,
        pre_filter: Callable = None,
    ):
        # Has to be set before the base class runs, ``process`` reads it.
        self.filename = filename
        stem = os.path.splitext(os.path.basename(filename))[0]
        root = os.path.join("data", stem)
        # ``process`` honours ``self.pre_filter``, so it is forwarded like in DTIDataset.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices, self.config = torch.load(self.processed_paths[0])

    def index(self, id: str):
        """Find protein by id.

        Args:
            id (str): Identifier to look up in ``self.data.id``

        Returns:
            The dataset entry at the position of ``id`` in ``self.data.id``.
        """
        return self[self.data.id.index(id)]

    @property
    def processed_file_names(self) -> Iterable[str]:
        """Which files have to be in the dir to consider dataset processed.

        Returns:
            Iterable[str]: list of files
        """
        return ["data.pt"]

    def process(self):
        """If the dataset was not seen before, process everything."""
        # NOTE(review): read_pickle unpickles the file, only open trusted files.
        raw = pd.read_pickle(self.filename)
        data_list = [Data(**entry) for entry in raw["data"]]
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        self.config = raw["config"]
        data, slices = self.collate(data_list)
        torch.save((data, slices, self.config), self.processed_paths[0])