# mlp
## Activation
Bases: Enum
Enum for activation functions.
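The `MLP` constructor documented below accepts either an `Activation` member or its string name. A minimal sketch of both forms; `Activation.GELU` is a plausible but unconfirmed member name, since this page does not list the enum's members (the string form `'gelu'` matches the Example further below):

```python
import torch

# Assumes MLP and Activation are imported from this module.
# Activation.GELU is hypothetical; only the string form is confirmed
# by the Example in the MLP docs below.
mlp_enum = MLP(size=128, hidden_size=256, activation=Activation.GELU,
               device='cuda', dtype=torch.float32)
mlp_str = MLP(size=128, hidden_size=256, activation='gelu',
              device='cuda', dtype=torch.float32)
```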
## MLP
Bases: Module
Multi-Layer Perceptron (MLP) module.
This module consists of two linear layers with an activation function in between. It supports various configurations such as the hidden size, activation function, initializing the output to zero, and recomputing the forward pass during backpropagation.
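A minimal sketch of the computation this description implies, in plain PyTorch. This is illustrative only, not the actual implementation; the real module additionally supports the zero-initialization, recompute, and memory-saving behavior documented below:

```python
import torch
from torch import nn

class MLPSketch(nn.Module):
    """Two linear layers with an activation in between, per the description."""

    def __init__(self, size: int, hidden_size: int) -> None:
        super().__init__()
        self.linear1 = nn.Linear(size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, size)
        self.activation = nn.GELU()  # stand-in for the Activation enum

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # linear1 -> activation -> linear2, mapping size -> hidden_size -> size.
        return self.linear2(self.activation(self.linear1(x)))
```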
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`size` | `int` | The input and output size of the MLP. | *required* |
`hidden_size` | `int` | The size of the hidden layer. | *required* |
`activation` | `Activation \| str` | The activation function to use. Can be either an `Activation` enum or a string representing the activation name. | *required* |
`device` | `device \| None` | The device to use for the linear layers. | *required* |
`dtype` | `dtype \| None` | The data type to use for the linear layers. | *required* |
`initialize_output_to_zero` | `bool` | Whether to initialize the output layer weights to zero. | `False` |
`recompute` | `bool` | Whether to recompute the forward pass during backpropagation. This can save memory but increase computation time. | `False` |
Attributes:

Name | Type | Description |
---|---|---|
`linear1` | `Linear` | The first linear layer. |
`linear2` | `Linear` | The second linear layer. |
`activation` | `Activation` | The activation function to use. |
Example

```python
mlp = MLP(size=128, hidden_size=256, activation='gelu', device='cuda')
x = torch.randn(32, 128, device='cuda', dtype=torch.float32)
output = mlp(x)
```
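The constructor flags documented above can be combined. A hedged sketch, with argument values chosen for illustration:

```python
import torch

# Assumes MLP is imported from this module.
# initialize_output_to_zero=True zero-initializes the output layer's
# weights, per the parameter table above; recompute=True re-runs the
# forward pass during backpropagation, trading compute for memory.
mlp = MLP(
    size=128,
    hidden_size=256,
    activation='gelu',
    device='cuda',
    dtype=torch.float32,
    initialize_output_to_zero=True,
    recompute=True,
)
```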
### forward
```python
forward(
    x: Tensor,
    *,
    add_input: bool = False,
    allow_inplace: bool = False,
    save_peak_mem_factor: int | None = None
) -> Tensor
```
Performs the forward pass of the MLP.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`x` | `Tensor` | The input tensor. | *required* |
`add_input` | `bool` | Whether to add the input to the output. | `False` |
`allow_inplace` | `bool` | Indicates that `x` is not used after the call and its buffer can be reused for the output. The operation is not guaranteed to be inplace. | `False` |
`save_peak_mem_factor` | `int \| None` | If provided, enables a memory-saving technique that reduces peak memory usage during the forward pass. This requires `add_input` and `allow_inplace` to be `True`. See the documentation of the `support_save_peak_mem_factor` decorator for details. | `None` |
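A usage sketch of the memory-saving path, reusing the `mlp` and `x` from the Example above; the factor value `8` is illustrative, not prescribed:

```python
# Per the table above, save_peak_mem_factor requires add_input=True and
# allow_inplace=True. add_input=True adds the input to the output, and
# allow_inplace=True declares that x's buffer may be reused. How the
# factor reduces peak memory is documented on the
# support_save_peak_mem_factor decorator.
output = mlp(
    x,
    add_input=True,
    allow_inplace=True,
    save_peak_mem_factor=8,
)
```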