rindti.data
- class TwoGraphData(**kwargs)
Subclass of torch_geometric.data.Data for protein and drug data.
- class DTIDataModule(filename: str, exp_name: str, batch_size: int = 128, num_workers: int = 1, shuffle: bool = True)
Data module for the DTI dataset.
- class DTIDataset(filename: str, exp_name: str, split: str = 'train', transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)
Dataset class for proteins and drugs.
- Parameters
filename (str) – Pickle file that stores the data.
split (str, optional) – Split type ('train', 'val' or 'test'). Defaults to 'train'.
transform (Callable, optional) – Transform applied to each data object on every access. Defaults to None.
pre_transform (Callable, optional) – Transform applied once to each data object during processing, before it is stored. Defaults to None.
- class PreTrainDataset(filename: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None)
Dataset class for pre-training.
- Parameters
filename (str) – Pickle file that stores the data.
transform (Callable, optional) – Transform applied to each data object on every access. Defaults to None.
pre_transform (Callable, optional) – Transform applied once to each data object during processing, before it is stored. Defaults to None.
- class DataCorruptor(frac: Dict[str, float], type: str = 'mask')
Corrupt or mask the nodes in a graph (or graph pair).
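The signature suggests that `frac` maps some per-graph feature key to a fraction of nodes and that `type` selects masking or corruption. The framework-free sketch below illustrates that idea on plain dicts of feature rows. Everything in it is an assumption, not rindti's implementation: the keys `prot_x` and `drug_x`, the `<key>_idx` output entries, zeroing as the masking rule, resampling rows as the corruption rule, and the names `DataCorruptorSketch` and `seed` are all hypothetical.

```python
import random
from typing import Dict, List, Optional

Features = List[List[float]]


class DataCorruptorSketch:
    """Hypothetical stand-in for DataCorruptor that works on plain dicts.

    Assumptions (not taken from rindti): `frac` maps a feature key such as
    "prot_x" to the fraction of its nodes to alter; "mask" zeroes the selected
    rows, "corrupt" overwrites them with copies of randomly drawn rows.
    """

    def __init__(self, frac: Dict[str, float], type: str = "mask", seed: Optional[int] = None) -> None:
        if type not in ("mask", "corrupt"):
            raise ValueError(f"unknown corruption type: {type!r}")
        self.frac = frac
        self.type = type
        self.rng = random.Random(seed)

    def __call__(self, data: Dict[str, Features]) -> dict:
        out = dict(data)  # shallow copy; keys not listed in frac are passed through
        for key, frac in self.frac.items():
            feats = data[key]
            n_alter = round(len(feats) * frac)
            idx = sorted(self.rng.sample(range(len(feats)), n_alter))
            new = [row[:] for row in feats]  # copy rows so the input is not modified
            for i in idx:
                if self.type == "mask":
                    new[i] = [0.0] * len(feats[i])
                else:
                    new[i] = list(self.rng.choice(feats))
            out[key] = new
            out[f"{key}_idx"] = idx  # hypothetical key recording the altered nodes
        return out


pair = {
    "prot_x": [[float(i), 1.0] for i in range(20)],
    "drug_x": [[float(i), 2.0] for i in range(6)],
}
masked = DataCorruptorSketch({"prot_x": 0.25}, type="mask", seed=0)(pair)
print(len(masked["prot_x_idx"]))  # 5
```

Only the graphs named in `frac` are touched, which is one plausible way a single corruptor could handle either a single graph or a protein/drug pair.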
- class SizeFilter(min_nnodes: int, max_nnodes: int = 0)
Filters out graphs that are too big or too small.
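A filter like this would typically be passed as a dataset's `pre_filter`, which in PyTorch Geometric returns True for graphs to keep. The sketch below is a hypothetical reimplementation under stated assumptions: the graph exposes a `num_nodes` attribute, bounds are inclusive, and `max_nnodes = 0` disables the upper bound. `SizeFilterSketch` and `ToyGraph` are illustrative names, not part of rindti.

```python
from dataclasses import dataclass


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a graph; only the node count matters here."""

    num_nodes: int


class SizeFilterSketch:
    """Keep graphs whose node count lies within [min_nnodes, max_nnodes].

    Assumption: max_nnodes == 0 means "no upper bound".
    Returns True to keep a graph, following the PyG pre_filter convention.
    """

    def __init__(self, min_nnodes: int, max_nnodes: int = 0) -> None:
        self.min_nnodes = min_nnodes
        self.max_nnodes = max_nnodes

    def __call__(self, data) -> bool:
        n = data.num_nodes
        if n < self.min_nnodes:
            return False  # too small
        if self.max_nnodes and n > self.max_nnodes:
            return False  # too big (only checked when an upper bound is set)
        return True


size_filter = SizeFilterSketch(min_nnodes=5, max_nnodes=100)
kept = [g.num_nodes for g in [ToyGraph(3), ToyGraph(50), ToyGraph(500)] if size_filter(g)]
print(kept)  # [50]
```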
- corrupt_features(features: torch.Tensor, frac: float) → Tuple[torch.Tensor, list]
Return corrupted features.
- Parameters
features (torch.Tensor) – Node features.
frac (float) – Fraction of nodes to corrupt.
- Returns
New corrupted features and the indices of the corrupted nodes.
- Return type
Tuple[torch.Tensor, list]
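The docstring does not say how features are corrupted. The sketch below assumes one common choice, replacing each selected row with a copy of a row drawn at random from the same graph so values stay in-distribution, and it uses plain lists instead of `torch.Tensor` to stay dependency-free. `corrupt_features_sketch`, its `rng` parameter, and rounding the node count with `round` are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Features = List[List[float]]


def corrupt_features_sketch(
    features: Features, frac: float, rng: Optional[random.Random] = None
) -> Tuple[Features, List[int]]:
    """Replace a random fraction of node feature rows with rows drawn from the same graph.

    Assumption: "corrupt" means resampling features from other nodes. A drawn
    row may coincide with the original one.
    """
    if not 0.0 <= frac <= 1.0:
        raise ValueError("frac must be between 0 and 1")
    if rng is None:
        rng = random.Random()
    num_nodes = len(features)
    # Pick distinct node indices to corrupt.
    idx = sorted(rng.sample(range(num_nodes), round(num_nodes * frac)))
    new_features = [row[:] for row in features]  # leave the input untouched
    for i in idx:
        new_features[i] = list(rng.choice(features))
    return new_features, idx


feats = [[float(i), float(i) ** 2] for i in range(10)]
new_feats, idx = corrupt_features_sketch(feats, 0.3, rng=random.Random(42))
print(len(idx))  # 3
```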
- mask_features(features: torch.Tensor, frac: float) → Tuple[torch.Tensor, list]
Return masked features.
- Parameters
features (torch.Tensor) – Node features.
frac (float) – Fraction of nodes to mask.
- Returns
New masked features and the indices of the masked nodes.
- Return type
Tuple[torch.Tensor, list]
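As a rough illustration of the masking variant, the sketch below overwrites every entry of the selected rows with a constant. The mask value of 0.0, the `mask_value` and `rng` parameters, and the name `mask_features_sketch` are assumptions for illustration; the real function may use a dedicated mask token and operates on `torch.Tensor` rather than lists.

```python
import random
from typing import List, Optional, Tuple

Features = List[List[float]]


def mask_features_sketch(
    features: Features,
    frac: float,
    mask_value: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Tuple[Features, List[int]]:
    """Overwrite a random fraction of node feature rows with a constant mask value.

    Assumption: masking replaces every entry of a selected row with `mask_value`.
    """
    if not 0.0 <= frac <= 1.0:
        raise ValueError("frac must be between 0 and 1")
    if rng is None:
        rng = random.Random()
    num_nodes = len(features)
    # Pick distinct node indices to mask.
    idx = sorted(rng.sample(range(num_nodes), round(num_nodes * frac)))
    new_features = [row[:] for row in features]  # leave the input untouched
    for i in idx:
        new_features[i] = [mask_value] * len(features[i])
    return new_features, idx


feats = [[1.0, 2.0, 3.0] for _ in range(8)]
masked, idx = mask_features_sketch(feats, 0.5, rng=random.Random(7))
print(len(idx))  # 4
```

Returning the indices alongside the features lets a pre-training objective compute its loss only on the masked nodes, which matches the documented return value.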