layer

LayerNorm

Bases: LayerNorm

Custom LayerNorm module that supports a peak-memory-saving factor.

This module extends the PyTorch LayerNorm implementation to handle FP16 inputs efficiently and to support the save_peak_mem_factor option for reducing peak memory usage.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*args` | `Any` | Positional arguments passed to the base `LayerNorm` class. | `()` |
| `**kwargs` | `Any` | Keyword arguments passed to the base `LayerNorm` class. | `{}` |

forward

forward(
    input: Tensor,
    *,
    allow_inplace: bool = False,
    save_peak_mem_factor: int | None = None
) -> Tensor

Perform layer normalization on the input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `Tensor` | The input tensor. | required |
| `allow_inplace` | `bool` | Whether to allow in-place operations on the input. | `False` |
| `save_peak_mem_factor` | `int \| None` | The factor for saving peak memory. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The layer-normalized tensor. |
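
Below is a minimal usage sketch. The import path and the constructor arguments are assumptions (the constructor simply forwards them to `torch.nn.LayerNorm`); only the keyword arguments of `forward` shown here come from the signature above.

```python
import torch

# Assumed import path; adjust to the package layout of your install.
from tabpfn.model.layer import LayerNorm

# *args/**kwargs are forwarded to torch.nn.LayerNorm.
norm = LayerNorm(512, eps=1e-5, elementwise_affine=False)

# Example input of shape (batch_size, num_items, num_feature_blocks, d_model);
# FP16 inputs are also supported according to the class description.
x = torch.randn(8, 100, 3, 512)

# Plain call: behaves like torch.nn.LayerNorm.
out = norm(x)

# Memory-saving call: save_peak_mem_factor reduces peak memory usage,
# and allow_inplace permits in-place operations on the input buffer.
out = norm(x, allow_inplace=True, save_peak_mem_factor=8)
```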

PerFeatureEncoderLayer

Bases: Module

Transformer encoder layer that processes each feature block separately.

This layer consists of multi-head attention between features, multi-head attention between items, and feedforward networks (MLPs). It supports various configurations and optimization options, including pre- or post-normalization, multiquery item attention, attention recomputation during backpropagation, and peak-memory saving.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | The dimensionality of the input and output embeddings. | required |
| `nhead` | `int` | The number of attention heads. | required |
| `dim_feedforward` | `int \| None` | The dimensionality of the feedforward network; `None` means `2 * d_model`. | `None` |
| `activation` | `str` | The activation function to use in the MLPs. | `'relu'` |
| `layer_norm_eps` | `float` | The epsilon value for layer normalization. | `1e-05` |
| `pre_norm` | `bool` | Whether to apply layer normalization before (rather than after) the attention and MLPs. | `False` |
| `device` | `device \| None` | The device to use for the layer parameters. | `None` |
| `dtype` | `dtype \| None` | The data type to use for the layer parameters. | `None` |
| `recompute_attn` | `bool` | Whether to recompute attention during backpropagation. | `False` |
| `second_mlp` | `bool` | Whether to include a second MLP in the layer. | `False` |
| `layer_norm_with_elementwise_affine` | `bool` | Whether to use elementwise affine parameters in layer normalization. | `False` |
| `zero_init` | `bool` | Whether to initialize the output of the MLPs to zero. | `False` |
| `save_peak_mem_factor` | `int \| None` | The factor for saving peak memory; only effective with post-norm. | `None` |
| `attention_between_features` | `bool` | Whether to apply attention between feature blocks. | `True` |
| `multiquery_item_attention` | `bool` | Whether to use multiquery attention for items. | `False` |
| `multiquery_item_attention_for_test_set` | `bool` | Whether to use multiquery attention for the test set. | `False` |
| `attention_init_gain` | `float` | The gain value for initializing attention parameters. | `1.0` |
| `d_k` | `int \| None` | The dimensionality of the query and key vectors; `None` means `d_model // nhead`. | `None` |
| `d_v` | `int \| None` | The dimensionality of the value vectors; `None` means `d_model // nhead`. | `None` |
| `precomputed_kv` | `None \| Tensor \| tuple[Tensor, Tensor]` | Precomputed key-value pairs for attention. | `None` |
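
As a rough illustration of the constructor options above, the sketch below builds a layer with mostly default settings; the import path is an assumption and the chosen values are arbitrary.

```python
# Assumed import path; adjust to the package layout of your install.
from tabpfn.model.layer import PerFeatureEncoderLayer

layer = PerFeatureEncoderLayer(
    d_model=512,                      # input/output embedding dimensionality
    nhead=4,                          # number of attention heads
    dim_feedforward=None,             # None falls back to 2 * d_model
    activation="gelu",                # activation used inside the MLPs
    pre_norm=False,                   # post-norm; save_peak_mem_factor is only effective here
    attention_between_features=True,  # attend across feature blocks
)
```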

empty_trainset_representation_cache

empty_trainset_representation_cache() -> None

Empty the trainset representation cache.

forward

forward(
    state: Tensor,
    single_eval_pos: int | None = None,
    *,
    cache_trainset_representation: bool = False,
    att_src: Tensor | None = None
) -> Tensor

Pass the input through the encoder layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `Tensor` | The transformer state passed as input to the layer, of shape `(batch_size, num_items, num_feature_blocks, d_model)`. | required |
| `single_eval_pos` | `int \| None` | The position from which on everything is treated as the test set. | `None` |
| `cache_trainset_representation` | `bool` | Whether to cache the trainset representation. If `single_eval_pos` is set (`> 0` and not `None`), a cache of the trainset KV is created, which may require a lot of memory; otherwise, the cached KV representations are used for inference. | `False` |
| `att_src` | `Tensor \| None` | The tensor to attend to from the final layer of the encoder, of shape `(batch_size, num_train_items, num_feature_blocks, d_model)`. Currently not compatible with `multiquery_item_attention_for_test_set` or `cache_trainset_representation`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The transformer state after passing through the encoder layer. |
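
A hedged end-to-end sketch of the forward pass with trainset-representation caching, using only the keyword arguments documented above; the import path is an assumption.

```python
import torch

# Assumed import path; adjust to the package layout of your install.
from tabpfn.model.layer import PerFeatureEncoderLayer

layer = PerFeatureEncoderLayer(d_model=512, nhead=4)

# State of shape (batch_size, num_items, num_feature_blocks, d_model).
state = torch.randn(2, 12, 3, 512)

# Items before single_eval_pos form the training set, the rest the test set;
# with caching enabled and single_eval_pos > 0, the trainset KV is stored.
out = layer(state, single_eval_pos=8, cache_trainset_representation=True)

# A later call without single_eval_pos reuses the cached trainset KV for
# inference on new test items.
test_state = torch.randn(2, 4, 3, 512)
out_test = layer(test_state, cache_trainset_representation=True)

# Release the cached representation once inference is done.
layer.empty_trainset_representation_cache()
```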