rindti.data

class TwoGraphData(**kwargs)

Subclass of torch_geometric.data.Data for protein and drug data.

n_nodes(prefix: str) → int

Return the number of nodes in the graph with the given prefix (e.g. 'prot' or 'drug').

n_edges(prefix: str) → int

Return the number of edges in the graph with the given prefix (e.g. 'prot' or 'drug').
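The prefix selects one of the two graphs stored on the same object. A minimal dependency-free sketch of that lookup, assuming attributes follow a `{prefix}_x` / `{prefix}_edge_index` naming scheme (the `prot_x` / `drug_x` names appear in the DataCorruptor parameters; `*_edge_index` is an assumption borrowed from the torch_geometric convention). `TwoGraphSketch` is hypothetical and uses plain lists instead of tensors:

```python
class TwoGraphSketch:
    """Hypothetical stand-in for TwoGraphData holding two graphs on one object."""

    def __init__(self, **kwargs):
        # Store every keyword argument as an attribute, as torch_geometric Data does.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def n_nodes(self, prefix: str) -> int:
        # Assumption: one row of node features per node in "{prefix}_x".
        return len(getattr(self, f"{prefix}_x"))

    def n_edges(self, prefix: str) -> int:
        # Assumption: edge_index is [sources, targets], so its row length is the edge count.
        return len(getattr(self, f"{prefix}_edge_index")[0])


pair = TwoGraphSketch(
    prot_x=[[1.0], [2.0], [3.0]],
    prot_edge_index=[[0, 1, 1, 2], [1, 0, 2, 1]],
    drug_x=[[0.5], [0.7]],
    drug_edge_index=[[0, 1], [1, 0]],
)
print(pair.n_nodes("prot"), pair.n_edges("prot"))  # 3 4
print(pair.n_nodes("drug"), pair.n_edges("drug"))  # 2 2
```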

class DTIDataModule(filename: str, exp_name: str, batch_size: int = 128, num_workers: int = 1, shuffle: bool = True)

Data module for the drug-target interaction (DTI) dataset.

setup(stage: Optional[str] = None)

Load the individual datasets.

update_config(config: dict) → None

Update the main config with the config of the dataset.
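Because the method returns None, it presumably updates the config it receives in place. A minimal sketch of one way such a merge could work, assuming the dataset exposes its own settings as a nested dict; the helper `update_config_sketch` and keys such as `prot_feat_dim` are hypothetical, not part of rindti:

```python
from typing import Any, Dict


def update_config_sketch(config: Dict[str, Any], dataset_config: Dict[str, Any]) -> None:
    """Hypothetical helper: merge dataset-derived settings into config in place."""
    for section, values in dataset_config.items():
        # Merge nested sections key by key instead of replacing them wholesale.
        if isinstance(values, dict) and isinstance(config.get(section), dict):
            config[section].update(values)
        else:
            config[section] = values


config = {"model": {"hidden_dim": 64}, "batch_size": 128}
dataset_config = {"model": {"prot_feat_dim": 20, "drug_feat_dim": 9}}  # hypothetical keys
update_config_sketch(config, dataset_config)
print(config)
# {'model': {'hidden_dim': 64, 'prot_feat_dim': 20, 'drug_feat_dim': 9}, 'batch_size': 128}
```

Merging section by section keeps settings the main config already had (here `hidden_dim`) while adding the dataset-specific ones.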

class DTIDataset(filename: str, exp_name: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Dataset class for proteins and drugs.

Parameters
  • filename (str) – Pickle file that stores the data.

  • split (str, optional) – Split type ('train', 'val' or 'test'). Defaults to 'train'.

  • transform (Callable, optional) – Transform applied to each sample on every access. Defaults to None.

  • pre_transform (Callable, optional) – Transform applied once to each sample, before the processed data is saved. Defaults to None.

property processed_file_names: Iterable[str]

Names of the files created by process().

process()

Process the whole dataset if it has not been processed before.
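In torch_geometric, a dataset's process() only runs when one of the files named by processed_file_names is missing from the processed directory; later instantiations load the saved result instead. A dependency-free sketch of that process-once pattern, assuming one pickle file per experiment name and split. The class `CachedDatasetSketch`, its file naming scheme and the placeholder samples are hypothetical:

```python
import os
import pickle
import tempfile
from typing import Callable, List, Optional


class CachedDatasetSketch:
    """Hypothetical sketch of the process-once pattern used by PyG datasets."""

    def __init__(self, root: str, exp_name: str, split: str = "train",
                 pre_transform: Optional[Callable] = None):
        self.root = root
        self.exp_name = exp_name
        self.split = split
        self.pre_transform = pre_transform
        self.process_calls = 0
        # Only process if any expected output file is missing.
        if not all(os.path.exists(p) for p in self.processed_paths):
            self.process()
        with open(self.processed_paths[0], "rb") as f:
            self.data = pickle.load(f)

    @property
    def processed_file_names(self) -> List[str]:
        # Assumption: one file per experiment and split.
        return [f"{self.exp_name}_{self.split}.pkl"]

    @property
    def processed_paths(self) -> List[str]:
        return [os.path.join(self.root, name) for name in self.processed_file_names]

    def process(self) -> None:
        self.process_calls += 1
        samples = [{"id": i} for i in range(3)]  # placeholder for reading the raw pickle
        if self.pre_transform is not None:
            # Applied once, before saving, so later loads skip this work.
            samples = [self.pre_transform(s) for s in samples]
        with open(self.processed_paths[0], "wb") as f:
            pickle.dump(samples, f)


root = tempfile.mkdtemp()
first = CachedDatasetSketch(root, "exp1")
second = CachedDatasetSketch(root, "exp1")  # files exist now, so nothing is reprocessed
print(first.process_calls, second.process_calls)  # 1 0
```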

class PreTrainDataModule(*args, **kwargs)

Data module for pre-training on proteins.

setup(stage: Optional[str] = None)

Load the individual datasets.

update_config(config: dict) → None

Update the main config with the config of the dataset.

class PreTrainDataset(filename: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None)

Dataset class for pre-training.

Parameters
  • filename (str) – Pickle file that stores the data.

  • transform (Callable, optional) – Transform applied to each sample on every access. Defaults to None.

  • pre_transform (Callable, optional) – Transform applied once to each sample, before the processed data is saved. Defaults to None.

index(id: str)

Find a protein by its id.
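The reference does not state what index() returns. Assuming it returns the position of the protein in the dataset, a dict-based sketch shows one way to make the lookup constant-time rather than a linear scan; `ProteinIndexSketch` and the ids are hypothetical:

```python
from typing import Dict, List


class ProteinIndexSketch:
    """Hypothetical sketch: map protein ids to their positions in a dataset."""

    def __init__(self, ids: List[str]):
        self.ids = ids
        # Build the lookup once so each query is O(1).
        self._pos: Dict[str, int] = {pid: i for i, pid in enumerate(ids)}

    def index(self, id: str) -> int:
        # Raises KeyError for ids that are not in the dataset.
        return self._pos[id]


ds = ProteinIndexSketch(["protA", "protB", "protC"])
print(ds.index("protB"))  # 1
```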

property processed_file_names: Iterable[str]

Files that must be present in the processed directory for the dataset to count as processed.

Returns

List of file names.

Return type

Iterable[str]

process()

Process the whole dataset if it has not been processed before.

class DataCorruptor(frac: Dict[str, float], type: str = 'mask')

Corrupt or mask the nodes in a graph (or graph pair).

Parameters
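  • frac (Dict[str, float]) – Maps each node-feature attribute to alter to the fraction of its nodes to corrupt or mask, e.g. {'x': 0.05} for a single graph or {'prot_x': 0.1, 'drug_x': 0.2} for a protein-drug pair.

  • type (str, optional) – Either 'corrupt' or 'mask'. 'corrupt' replaces the selected nodes' features with values sampled from the existing features; 'mask' sets them to zero. Defaults to 'mask'.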
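Since each key in frac carries its own fraction, one corruptor can treat the protein and drug graphs of a pair differently. A sketch of the node-selection step only, assuming the number of nodes is the rounded fraction and nodes are drawn uniformly without replacement; the helper `select_nodes` is hypothetical, not rindti's code:

```python
import random
from typing import Dict, List


def select_nodes(num_nodes: Dict[str, int], frac: Dict[str, float],
                 seed: int = 0) -> Dict[str, List[int]]:
    """Hypothetical helper: pick node indices to alter for each attribute in frac."""
    rng = random.Random(seed)
    selected = {}
    for attr, f in frac.items():
        n = num_nodes[attr]
        k = round(n * f)  # assumption: count is the rounded fraction
        # Sample without replacement so no node is picked twice.
        selected[attr] = sorted(rng.sample(range(n), k))
    return selected


sel = select_nodes({"prot_x": 50, "drug_x": 20}, {"prot_x": 0.1, "drug_x": 0.2})
print({attr: len(idx) for attr, idx in sel.items()})  # {'prot_x': 5, 'drug_x': 4}
```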

class SizeFilter(min_nnodes: int, max_nnodes: int = 0)

Filter out graphs that have too many or too few nodes.
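A sketch of the kind of predicate such a filter applies, assuming inclusive bounds and that the default max_nnodes=0 disables the upper limit; neither detail is confirmed by this reference, and `keep_graph` is hypothetical:

```python
def keep_graph(n_nodes: int, min_nnodes: int, max_nnodes: int = 0) -> bool:
    """Hypothetical size check; assumes max_nnodes=0 means 'no upper limit'."""
    if n_nodes < min_nnodes:
        return False
    return max_nnodes == 0 or n_nodes <= max_nnodes


sizes = [3, 10, 150, 800]
print([s for s in sizes if keep_graph(s, min_nnodes=5, max_nnodes=500)])  # [10, 150]
print([s for s in sizes if keep_graph(s, min_nnodes=5)])  # [10, 150, 800]
```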

corrupt_features(features: torch.Tensor, frac: float) → Tuple[torch.Tensor, list]

Return corrupted features.

Parameters
  • features (torch.Tensor) – Node features

  • frac (float) – Fraction of nodes to corrupt

Returns

New corrupted features, indices of the corrupted nodes

Return type

torch.Tensor, list
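A pure-Python sketch of this operation, using lists of rows in place of a torch.Tensor. It assumes each selected node receives the full feature row of a randomly chosen node (the reference only says new values are sampled from the old ones) and that the node count is the rounded fraction; `corrupt_features_sketch` is hypothetical:

```python
import random
from typing import List, Tuple

Features = List[List[float]]


def corrupt_features_sketch(features: Features, frac: float,
                            seed: int = 0) -> Tuple[Features, List[int]]:
    """Hypothetical: replace a fraction of rows with rows drawn from the original features."""
    rng = random.Random(seed)
    n = len(features)
    idx = sorted(rng.sample(range(n), round(n * frac)))
    new = [row[:] for row in features]  # copy so the input is left untouched
    for i in idx:
        # Sample a replacement from the original rows (it may be row i itself).
        new[i] = features[rng.randrange(n)][:]
    return new, idx


feats = [[float(i), float(i)] for i in range(10)]
new, idx = corrupt_features_sketch(feats, 0.3)
print(len(idx))  # 3
```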

mask_features(features: torch.Tensor, frac: float) → Tuple[torch.Tensor, list]

Return masked features.

Parameters
  • features (torch.Tensor) – Node features

  • frac (float) – Fraction of nodes to mask

Returns

New masked features, indices of the masked nodes

Return type

torch.Tensor, list
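Masking differs from corruption only in what the selected rows receive: per the DataCorruptor parameters, masked nodes get zeroes. A pure-Python sketch with lists in place of tensors, assuming the node count is the rounded fraction; `mask_features_sketch` is hypothetical:

```python
import random
from typing import List, Tuple

Features = List[List[float]]


def mask_features_sketch(features: Features, frac: float,
                         seed: int = 0) -> Tuple[Features, List[int]]:
    """Hypothetical: zero out the feature rows of a random fraction of nodes."""
    rng = random.Random(seed)
    n = len(features)
    idx = sorted(rng.sample(range(n), round(n * frac)))
    new = [row[:] for row in features]  # copy so the input is left untouched
    for i in idx:
        new[i] = [0.0] * len(new[i])
    return new, idx


feats = [[1.0, 2.0] for _ in range(4)]
masked, idx = mask_features_sketch(feats, 0.5)
print(len(idx), sum(row == [0.0, 0.0] for row in masked))  # 2 2
```

Returning the indices alongside the features presumably lets a pre-training objective score predictions only on the nodes that were masked.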