from argparse import ArgumentParser
from math import ceil
import torch
import torch.nn.functional as F
import torch_geometric
from torch.functional import Tensor
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool, dense_mincut_pool
from torch_geometric.typing import Adj
from ..base_layer import BaseLayer
class DiffPoolNet(BaseLayer):
    """Differential Pooling module.

    Refer to :func:`torch_geometric.nn.dense.dense_diff_pool` and
    :func:`torch_geometric.nn.dense.dense_mincut_pool` for more details.

    The input graphs are converted to a zero-padded dense batch, coarsened
    twice by the selected pooling function, embedded once more, mean-pooled
    over the remaining clusters, projected to ``output_dim`` and
    L2-normalised.

    Args:
        input_dim (int): Size of the input vector
        output_dim (int): Size of the output vector
        hidden_dim (int, optional): Size of the hidden vector. Defaults to 128.
        max_nodes (int, optional): Maximal number of nodes in a graph. The
            dense padding size actually used is ``ceil(max_nodes * 1.2)``.
            Defaults to 600.
        dropout (float, optional): Dropout ratio applied to the output of the
            final linear layer. Defaults to 0.2.
        ratio (float, optional): Pooling ratio; each pooling step keeps
            ``ceil(previous_size * ratio)`` clusters. Defaults to 0.25.
        pooling_method (str, optional): Type of pooling, either
            ``"diffpool"`` or ``"mincut"``. Defaults to "mincut".

    Raises:
        ValueError: If ``pooling_method`` is not a supported pooling method.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
        max_nodes: int = 600,
        dropout: float = 0.2,
        ratio: float = 0.25,
        pooling_method: str = "mincut",
        **kwargs,
    ):
        super().__init__()
        # Dense padding size: 20% head-room over the nominal maximum graph size.
        self.max_nodes = ceil(max_nodes * 1.2)
        self.dropout = dropout

        pool_fns = {
            "diffpool": dense_diff_pool,
            "mincut": dense_mincut_pool,
        }
        if pooling_method not in pool_fns:
            raise ValueError(
                f"Unknown pooling_method {pooling_method!r}; "
                f"expected one of {sorted(pool_fns)}."
            )
        self.pool = pool_fns[pooling_method]

        # First coarsening: self.max_nodes -> ceil(self.max_nodes * ratio) clusters.
        num_nodes = ceil(self.max_nodes * ratio)
        self.poolblock1 = DiffPoolBlock(input_dim, num_nodes)
        self.embedblock1 = DiffPoolBlock(input_dim, hidden_dim)
        # Second coarsening: shrink the cluster count by ``ratio`` once more.
        num_nodes = ceil(num_nodes * ratio)
        self.poolblock2 = DiffPoolBlock(hidden_dim, num_nodes)
        self.embedblock2 = DiffPoolBlock(hidden_dim, hidden_dim)
        self.embedblock3 = DiffPoolBlock(hidden_dim, hidden_dim)
        self.lin1 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor, edge_index: Adj, batch: Tensor, **kwargs) -> Tensor:
        """Compute one embedding per graph in the batch.

        Args:
            x (Tensor): Node features of all graphs in the batch.
            edge_index (Adj): Connectivity of the whole batch.
            batch (Tensor): Index of the graph each node belongs to.
            **kwargs: Ignored.

        Returns:
            Tensor: L2-normalised embeddings of shape ``(num_graphs, output_dim)``.
        """
        # Shapes below: B = graphs in batch, N = self.max_nodes,
        # C1 / C2 = cluster counts of the first / second pooling step.
        # Zero-padded dense batch; ``mask`` marks real (non-padding) nodes.
        x, mask = torch_geometric.utils.to_dense_batch(
            x, batch, max_num_nodes=self.max_nodes
        )
        adj = torch_geometric.utils.to_dense_adj(
            edge_index, batch, max_num_nodes=self.max_nodes
        )

        # First pooling step. Padding rows are not zero after the conv /
        # batch-norm blocks, so the mask is passed to keep them out of the
        # cluster assignment and out of the pooled features.
        s = self.poolblock1(x, adj)  # (B, N, C1) assignment logits
        x = self.embedblock1(x, adj)  # (B, N, hidden_dim)
        # The two auxiliary regularisation losses returned by pooling are unused.
        x, adj, _, _ = self.pool(x, adj, s, mask)

        # Second pooling step: every cluster produced above is a real node.
        s = self.poolblock2(x, adj)  # (B, C1, C2)
        x = self.embedblock2(x, adj)  # (B, C1, hidden_dim)
        x, adj, _, _ = self.pool(x, adj, s)

        x = self.embedblock3(x, adj)  # (B, C2, hidden_dim)
        x = F.relu(x)
        x = x.mean(dim=1)  # average over clusters -> (B, hidden_dim)
        x = self.lin1(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.normalize(x, dim=1)
class DiffPoolBlock(torch.nn.Module):
    """Block of DiffPool: dense SAGE convolution -> ReLU -> batch normalisation.

    Used both to produce cluster-assignment logits and node embeddings.

    Args:
        in_channels (int): Size of each input node feature vector.
        out_channels (int): Size of each output node feature vector.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = DenseSAGEConv(in_channels, out_channels)
        self.bn1 = torch.nn.BatchNorm1d(out_channels)

    def bn(self, i: int, x: Tensor) -> Tensor:
        """Apply the batch-norm layer stored as attribute ``bn{i}``.

        The 3-D input is flattened so that every node of every graph is one
        sample for ``BatchNorm1d``, then restored to its original shape.

        Args:
            i (int): Layer index; selects ``self.bn{i}`` (only ``bn1`` exists).
            x (Tensor): Node features of shape ``(batch, nodes, channels)``.

        Returns:
            Tensor: Normalised node features, same shape as ``x``.
        """
        n_graphs, n_nodes, n_feats = x.size()
        flat = x.view(-1, n_feats)
        normed = getattr(self, f"bn{i}")(flat)
        return normed.view(n_graphs, n_nodes, n_feats)

    def forward(self, x: Tensor, adj: Adj) -> Tensor:
        """Convolve dense node features over ``adj``, apply ReLU, then batch norm.

        Args:
            x (Tensor): Dense node features.
            adj (Adj): Dense adjacency matching ``x``.

        Returns:
            Tensor: Updated node features with ``out_channels`` channels.
        """
        activated = F.relu(self.conv1(x, adj))
        return self.bn(1, activated)