# Source code for rindti.layers.graphconv.ginconv

from argparse import ArgumentParser

from torch import nn
from torch.functional import Tensor
from torch_geometric.nn import GINConv
from torch_geometric.typing import Adj

from ..base_layer import BaseLayer


class GINConvNet(BaseLayer):
    """Graph Isomorphism Network.

    Refer to :class:`torch_geometric.nn.conv.GINConv` for more details.

    Every layer is a ``GINConv`` whose update MLP is
    ``Linear -> PReLU -> Linear -> BatchNorm1d``. The stack consists of one
    input layer (``input_dim -> hidden_dim``), ``num_layers - 2`` middle
    layers (``hidden_dim -> hidden_dim``) and one output layer
    (``hidden_dim -> output_dim``).

    Args:
        input_dim (int): Size of the input vector
        output_dim (int): Size of the output vector
        hidden_dim (int, optional): Size of the hidden vector. Defaults to 64.
        num_layers (int, optional): Total number of layers. Defaults to 3.
            The input and output layers are always built, so values below 2
            still produce a 2-layer network.
        **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 64, num_layers: int = 3, **kwargs):
        super().__init__()
        # Attribute names (inp / mid_layers / out) define the state_dict keys, and
        # the construction order below defines parameter-initialisation order;
        # keep both stable so existing checkpoints and seeded runs still match.
        self.inp = self._gin_layer(input_dim, hidden_dim, hidden_dim)
        self.mid_layers = nn.ModuleList(
            [self._gin_layer(hidden_dim, hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
        )
        self.out = self._gin_layer(hidden_dim, hidden_dim, output_dim)

    @staticmethod
    def _gin_layer(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
        """Build a GINConv whose MLP maps ``in_dim -> hidden_dim -> out_dim``."""
        return GINConv(
            nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.PReLU(),
                nn.Linear(hidden_dim, out_dim),
                nn.BatchNorm1d(out_dim),
            )
        )

    def forward(self, x: Tensor, edge_index: Adj, **kwargs) -> Tensor:
        """Apply the input, middle and output GIN layers in sequence.

        Args:
            x (Tensor): Node feature matrix; presumably ``(num_nodes, input_dim)``
                (TODO confirm against callers).
            edge_index (Adj): Graph connectivity, passed unchanged to every layer.
            **kwargs: Ignored.

        Returns:
            Tensor: Output of the final layer, with ``output_dim`` features per node.
        """
        x = self.inp(x, edge_index)
        for module in self.mid_layers:
            x = module(x, edge_index)
        return self.out(x, edge_index)