Source code for rindti.data.transforms

import pickle
from copy import deepcopy
from math import ceil
from typing import Dict, Tuple, Union

import numpy as np
import torch
from torch_geometric.data import Data

from .data import TwoGraphData


class SizeFilter:
    """Filters out graphs that are too big/small.

    Both bounds are exclusive: a graph passes when
    ``min_nnodes < num_nodes < max_nnodes``.

    Args:
        min_nnodes (int): Exclusive lower bound on the number of nodes.
        max_nnodes (int, optional): Exclusive upper bound on the number of nodes.
            ``0`` (the default) disables the upper bound. Defaults to 0.
    """

    def __init__(self, min_nnodes: int, max_nnodes: int = 0):
        self.min_nnodes = min_nnodes
        self.max_nnodes = max_nnodes

    def __call__(self, data: "Data") -> bool:
        """Returns True if number of nodes in given graph is within required values else False.

        Args:
            data (Data): Graph to check; only its ``num_nodes`` attribute is read.

        Returns:
            bool: Whether the graph passes the size filter.
        """
        nnodes = data.num_nodes
        if nnodes <= self.min_nnodes:
            return False
        # A literal `nnodes < 0` comparison can never hold, so the default of 0
        # would reject every graph; treat it as "no upper limit" instead.
        return not self.max_nnodes or nnodes < self.max_nnodes
class DataCorruptor:
    """Corrupt or mask the nodes in a graph (or graph pair).

    Args:
        frac (Dict[str, float]): dict of which attributes to corrupt and by what fraction
            ({'x' : 0.05} or {'prot_x' : 0.1, 'drug_x' : 0.2}). Entries whose fraction
            is not greater than 0 are dropped.
        type (str, optional): 'corrupt' or 'mask'. Corrupt puts new values sampled from old,
            mask puts zeroes. Defaults to 'mask'.

    Raises:
        ValueError: If ``type`` is neither 'mask' nor 'corrupt'.
    """

    def __init__(self, frac: Dict[str, float], type: str = "mask"):
        # `type` shadows the builtin, but it is part of the public signature.
        self.type = type
        self.frac = {k: v for k, v in frac.items() if v > 0}
        self._set_corr_func()

    def _set_corr_func(self):
        """Sets the necessary corruption function.

        Raises:
            ValueError: If ``self.type`` is not a known corruption type.
        """
        if self.type == "mask":
            self.corr_func = mask_features
        elif self.type == "corrupt":
            self.corr_func = corrupt_features
        else:
            # Fail at construction instead of leaving `corr_func` unset, which
            # would only show up later as an AttributeError inside __call__.
            raise ValueError(f"Unknown corruption type {self.type!r}, expected 'mask' or 'corrupt'")

    def __call__(self, data: "Union[Data, TwoGraphData]") -> "TwoGraphData":
        """Apply corruption.

        The given object is modified in place and returned. For every key ``k`` in
        ``self.frac`` the selected rows of ``data[k]`` are overwritten, the previous
        values of those rows are stored under ``k + "_orig"`` and the selected row
        indices under ``k + "_idx"``.

        Args:
            data (Union[Data, TwoGraphData]): data, has to have attributes that match ones from self.frac

        Returns:
            TwoGraphData: Data with corrupted features
        """
        for key, fraction in self.frac.items():
            new_feat, idx = self.corr_func(data[key], fraction)
            # Keep a detached copy of the original rows before they are overwritten.
            data[key + "_orig"] = data[key][idx].detach().clone()
            data[key + "_idx"] = idx
            data[key][idx] = new_feat
        return data
def corrupt_features(features: torch.Tensor, frac: float) -> Tuple[torch.Tensor, list]:
    """Return corrupt features.

    Picks ``ceil(num_nodes * frac)`` distinct rows to be replaced and, independently,
    the same number of distinct rows of ``features`` to serve as replacement values.
    ``features`` itself is not modified.

    Args:
        features (torch.Tensor): Node features
        frac (float): Fraction of nodes to corrupt

    Returns:
        torch.Tensor, list: New corrupt features, idx of corrupted nodes

    Raises:
        ValueError: If ``frac`` is not between 0 and 1.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not 0 <= frac <= 1:
        raise ValueError("frac has to be between 0 and 1!")
    num_nodes = features.size(0)
    num_corrupt_nodes = ceil(num_nodes * frac)
    # Passing the population size as an int always yields integer indices,
    # including for an empty input. Draw order (targets, then sources) is kept.
    idx = np.random.choice(num_nodes, num_corrupt_nodes, replace=False).tolist()
    source_rows = np.random.choice(num_nodes, num_corrupt_nodes, replace=False)
    source_rows = torch.as_tensor(source_rows, dtype=torch.long, device=features.device)
    # Indexing with a tensor returns a new tensor, not a view of `features`.
    return features[source_rows], idx
def mask_features(features: torch.Tensor, frac: float) -> Tuple[torch.Tensor, list]:
    """Return masked features.

    Picks ``ceil(num_nodes * frac)`` distinct rows and returns an all-zero tensor
    shaped like those rows. ``features`` itself is not modified.

    Args:
        features (torch.Tensor): Node features
        frac (float): Fraction of nodes to mask

    Returns:
        torch.Tensor, list: New masked features, idx of masked nodes

    Raises:
        ValueError: If ``frac`` is not between 0 and 1.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not 0 <= frac <= 1:
        raise ValueError("frac has to be between 0 and 1!")
    num_nodes = features.size(0)
    num_masked_nodes = ceil(num_nodes * frac)
    # Passing the population size as an int always yields integer indices,
    # including for an empty input.
    idx = np.random.choice(num_nodes, num_masked_nodes, replace=False).tolist()
    # Allocate the zeros directly (same dtype and device as `features`)
    # instead of gathering the selected rows only to read their shape.
    masked = features.new_zeros((len(idx), *features.shape[1:]))
    return masked, idx