mlp

Activation

Bases: Enum

Enum for activation functions.
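The available members are not listed on this page. As a hedged illustration only, passing an activation to MLP either as an enum member or as a string might look like the lines below; the member name GELU is an assumption, not confirmed by this page.

# Both forms are accepted per the MLP 'activation' parameter documented below.
mlp_enum = MLP(size=128, hidden_size=256, activation=Activation.GELU, device='cuda')  # assumed member name
mlp_str = MLP(size=128, hidden_size=256, activation='gelu', device='cuda')            # string form, as in the example below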

MLP

Bases: Module

Multi-Layer Perceptron (MLP) module.

This module consists of two linear layers with an activation function in between. It supports several configuration options, such as the hidden size, the activation function, initializing the output layer weights to zero, and recomputing the forward pass during backpropagation.
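Conceptually, and ignoring the zero-initialization, recompute, and memory-saving options, the module corresponds to the following plain PyTorch sketch; the GELU activation here is only an illustrative choice, not the module's fixed behavior.

import torch.nn as nn
import torch.nn.functional as F

class PlainMLP(nn.Module):
    # Rough structural equivalent: size -> hidden_size -> size, with an
    # activation applied between the two linear layers.
    def __init__(self, size: int, hidden_size: int):
        super().__init__()
        self.linear1 = nn.Linear(size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, size)

    def forward(self, x):
        return self.linear2(F.gelu(self.linear1(x)))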

Parameters:

size (int): The input and output size of the MLP. Required.
hidden_size (int): The size of the hidden layer. Required.
activation (Activation | str): The activation function to use. Can be either an Activation enum or a string representing the activation name. Required.
device (device | None): The device to use for the linear layers. Required.
dtype (dtype | None): The data type to use for the linear layers. Required.
initialize_output_to_zero (bool): Whether to initialize the output layer weights to zero. Default: False.
recompute (bool): Whether to recompute the forward pass during backpropagation. This can save memory but increase computation time. Default: False.
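The recompute option describes the usual activation-checkpointing trade-off: intermediate activations are discarded during the forward pass and recomputed during the backward pass. Below is a minimal sketch of that general technique using torch.utils.checkpoint; it is not necessarily how this module implements recompute.

import torch
from torch.utils.checkpoint import checkpoint

def block(x, linear1, linear2):
    # The hidden activation produced here is recomputed during the backward
    # pass instead of being stored, trading extra compute for lower peak memory.
    return linear2(torch.nn.functional.gelu(linear1(x)))

linear1 = torch.nn.Linear(128, 256)
linear2 = torch.nn.Linear(256, 128)
x = torch.randn(32, 128, requires_grad=True)
out = checkpoint(block, x, linear1, linear2, use_reentrant=False)
out.sum().backward()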

Attributes:

linear1 (Linear): The first linear layer.
linear2 (Linear): The second linear layer.
activation (Activation): The activation function to use.

Example

mlp = MLP(size=128, hidden_size=256, activation='gelu', device='cuda')
x = torch.randn(32, 128, device='cuda', dtype=torch.float32)
output = mlp(x)

forward

forward(
    x: Tensor,
    *,
    add_input: bool = False,
    allow_inplace: bool = False,
    save_peak_mem_factor: int | None = None
) -> Tensor

Performs the forward pass of the MLP.

Parameters:

x (Tensor): The input tensor. Required.
add_input (bool): Whether to add the input to the output. Default: False.
allow_inplace (bool): Indicates that 'x' is not used after the call and its buffer can be reused for the output. The operation is not guaranteed to be in-place. Default: False.
save_peak_mem_factor (int | None): If provided, enables a memory-saving technique that reduces peak memory usage during the forward pass. This requires 'add_input' and 'allow_inplace' to be True. See the documentation of the 'support_save_peak_mem_factor' decorator for details. Default: None.
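
Assuming the mlp and x from the example above, a call that exercises these flags could look like the following; the factor value 8 is purely illustrative.

# Residual-style call: the input is added to the output, the input buffer may be
# reused, and peak memory during the forward pass is reduced.
output = mlp(x, add_input=True, allow_inplace=True, save_peak_mem_factor=8)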