Source code for neupi.models.nn
from typing import List, Optional
import torch
import torch.nn as nn
from neupi.core.embedding import Embedding
from neupi.core.model import BaseNNModel
from neupi.registry import register
[docs]
@register("nn_model")
class MLP(BaseNNModel):
"""
A flexible Multi-Layer Perceptron (MLP) network.
This module creates a fully connected neural network with configurable
hidden layers, activation functions, batch normalization, and dropout. We extract the input size from the embedding.
Args:
hidden_sizes (List[int]): A list where each element is the number of
neurons in a hidden layer.
output_size (int): The number of neurons in the output layer.
hidden_activation (str): The activation function for hidden layers.
Supported: 'relu', 'leaky_relu'.
use_batchnorm (bool): If True, adds a BatchNorm1d layer after each
hidden activation. Defaults to True.
dropout_rate (float): The dropout probability. If 0.0, no dropout is
applied. Defaults to 0.0.
"""
def __init__(
self,
hidden_sizes: List[int],
output_size: int,
embedding: Embedding = None,
hidden_activation: str = "relu",
use_batchnorm: bool = True,
dropout_rate: float = 0.0,
):
super().__init__()
self.embedding = embedding
if hidden_activation == "relu":
activation_fn = nn.ReLU
elif hidden_activation == "leaky_relu":
activation_fn = nn.LeakyReLU
else:
raise ValueError("Unsupported activation function.")
layers = []
current_size = embedding.embedding_size
for hidden_size in hidden_sizes:
layers.append(nn.Linear(current_size, hidden_size))
if use_batchnorm:
layers.append(nn.BatchNorm1d(hidden_size))
layers.append(activation_fn())
if dropout_rate > 0:
layers.append(nn.Dropout(dropout_rate))
current_size = hidden_size
self.hidden_layers = nn.Sequential(*layers)
self.output_layer = nn.Linear(current_size, output_size)
self.initialize_weights(hidden_activation)
[docs]
def initialize_weights(self, nonlinearity: str = "relu"):
"""
Initializes the weights of the network using appropriate methods.
"""
for layer in self.hidden_layers:
if isinstance(layer, nn.Linear):
nn.init.kaiming_normal_(
layer.weight,
mode="fan_in",
nonlinearity=nonlinearity,
)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
nn.init.xavier_uniform_(self.output_layer.weight)
if self.output_layer.bias is not None:
nn.init.zeros_(self.output_layer.bias)
[docs]
def forward(
self,
evidence_data: torch.Tensor,
evidence_mask: torch.Tensor,
query_mask: torch.Tensor,
unobs_mask: torch.Tensor,
) -> torch.Tensor:
"""
Defines the forward pass of the MLP.
"""
if self.embedding is not None:
x = self.embedding(evidence_data, evidence_mask, query_mask, unobs_mask)
else:
raise ValueError(
"No embedding provided; this is not recommended. If you want to use an identity embedding, use the `IdentityEmbedding` class."
)
x = self.hidden_layers(x)
output = self.output_layer(x)
return output