Source code for rindti.layers.graphconv.transformer
from argparse import ArgumentParser
from torch import Tensor, nn
from torch_geometric.nn import TransformerConv
from torch_geometric.typing import Adj
from ..base_layer import BaseLayer
[docs]class TransformerNet(BaseLayer):
"""Transformer Network.
Refer to :class:`torch_geometric.nn.conv.TransformerConv` for more details.
Args:
input_dim (int): Size of the input vector
output_dim (int): Size of the output vector
hidden_dim (int, optional): Size of the hidden vector. Defaults to 32.
dropout (float, optional): Dropout probability. Defaults to 0.1.
edge_dim (int, optional): Size of the edge vector. Defaults to None.
edge_type (int, optional): Number of edge types. Defaults to "none.
heads (int, optional): Number of heads. Defaults to 1.
num_layers (int, optional): Number of layers. Defaults to 3.
"""
def __init__(
self,
input_dim,
output_dim: int,
hidden_dim: int = 32,
dropout: float = 0.1,
edge_dim: int = None,
edge_type: str = "none",
heads: int = 1,
num_layers: int = 3,
**kwargs,
):
super().__init__()
self.edge_type = edge_type
if edge_type == "none":
edge_dim = None
if edge_type == "label":
self.edge_embed = nn.Embedding(edge_dim + 1, edge_dim)
self.inp = TransformerConv(
input_dim,
hidden_dim,
heads=heads,
dropout=dropout,
edge_dim=edge_dim,
concat=False,
)
self.mid_layers = nn.ModuleList(
[
TransformerConv(
hidden_dim,
hidden_dim,
heads=heads,
dropout=dropout,
edge_dim=edge_dim,
concat=False,
)
for _ in range(num_layers - 2)
]
)
self.out = TransformerConv(hidden_dim, output_dim, heads=1, dropout=dropout, edge_dim=edge_dim, concat=False)
[docs] def forward(self, x: Tensor, edge_index: Adj, edge_feats: Tensor = None, **kwargs) -> Tensor:
""""""
if self.edge_type == "none":
edge_feats = None
elif self.edge_type == "label":
edge_feats = self.edge_embed(edge_feats)
x = self.inp(x, edge_index, edge_feats)
for module in self.mid_layers:
x = module(x, edge_index, edge_feats)
x = self.out(x, edge_index, edge_feats)
return x