# Source code for rindti.layers.other.mlp

from argparse import ArgumentParser

from torch import nn
from torch.functional import Tensor

from ..base_layer import BaseLayer


class MLP(BaseLayer):
    """Simple multi-layer perceptron.

    Refer to :class:`torch.nn.Sequential` for more details.

    The network is built as ``Linear -> ReLU -> Dropout`` mapping ``input_dim``
    to ``hidden_dim``, followed by ``num_layers - 2`` hidden blocks of the same
    shape (``hidden_dim`` to ``hidden_dim``), and a final ``Linear`` mapping
    ``hidden_dim`` to ``out_dim`` with no activation or dropout after it.

    Args:
        input_dim (int): Size of the input vector
        out_dim (int): Size of the output vector
        hidden_dim (int, optional): Size of the hidden vector. Defaults to 64.
        num_layers (int, optional): Total number of linear layers. Values
            below 2 add no hidden blocks, so the result is the same as for 2.
            Defaults to 2.
        dropout (float, optional): Dropout ratio. Defaults to 0.2.
        **kwargs: Extra keyword arguments are accepted but not used by this
            constructor.
    """

    def __init__(
        self,
        input_dim: int,
        out_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 2,
        dropout: float = 0.2,
        **kwargs,
    ):
        super().__init__()
        # The input block is passed positionally, so its submodules keep the
        # automatic names "0", "1" and "2" inside the Sequential container.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # The input and output layers account for two of ``num_layers``; only
        # the remainder become hidden blocks (an empty range when
        # ``num_layers <= 2``).
        for i in range(num_layers - 2):
            self.mlp.add_module(f"hidden_linear{i}", nn.Linear(hidden_dim, hidden_dim))
            self.mlp.add_module(f"hidden_relu{i}", nn.ReLU())
            self.mlp.add_module(f"hidden_dropout{i}", nn.Dropout(dropout))
        self.mlp.add_module("final_linear", nn.Linear(hidden_dim, out_dim))

    def forward(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the sequential stack and return the result."""
        return self.mlp(x)