transformer

LayerStack

Bases: Module

Similar to nn.Sequential, but with support for passing keyword arguments to layers and for stacking the same layer multiple times.
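
A minimal sketch of the idea (the layer_creator/num_layers constructor shown here is an assumption for illustration, not the documented API):

```python
import torch.nn as nn

class KwargSequential(nn.Module):
    """Hypothetical sketch of the LayerStack idea (not the actual
    implementation): stack num_layers copies of one layer type and
    forward keyword arguments that nn.Sequential cannot accept."""

    def __init__(self, layer_creator, num_layers: int):
        super().__init__()
        # Unlike nn.Sequential, build N independent copies of one layer type.
        self.layers = nn.ModuleList([layer_creator() for _ in range(num_layers)])

    def forward(self, x, **kwargs):
        for layer in self.layers:
            x = layer(x, **kwargs)  # nn.Sequential would reject **kwargs here
        return x
```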

PerFeatureTransformer

Bases: Module

A Transformer model that processes one token per feature and sample.

This model extends the standard Transformer architecture to operate on a per-feature basis. It allows for processing each feature separately while still leveraging the power of self-attention.

The model consists of an encoder, decoder, and optional components such as a feature positional embedding and a separate decoder for each feature.
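
To make the "token per feature and sample" idea concrete, the toy sketch below (purely illustrative; the shapes and the per-cell encoder are not this model's API) expands a table of S samples and F features into S * F tokens:

```python
import torch
import torch.nn as nn

# A table of S samples x F features becomes one token per (sample, feature)
# cell, so self-attention can relate features as well as samples.
S, F, emsize = 4, 3, 8
x = torch.randn(S, F)                    # raw tabular input
cell_encoder = nn.Linear(1, emsize)      # illustrative per-cell encoder
tokens = cell_encoder(x.unsqueeze(-1))   # shape: (S, F, emsize)
```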

forward

forward(*args: Any, **kwargs: Any) -> dict[str, Tensor]

Performs a forward pass through the model.

This method supports multiple calling conventions, illustrated in the sketch after this list:

  • model((x,y), **kwargs)
  • model(train_x, train_y, test_x, **kwargs)
  • model((style,x,y), **kwargs)
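
For illustration, the first two conventions might be used as follows (a sketch: the tensor shapes, the (num_samples, batch_size, num_features) layout, and the pre-built model are assumptions, not taken from this page):

```python
import torch

# Assume `model` is an already-constructed PerFeatureTransformer and that
# inputs use a (num_samples, batch_size, num_features) layout.
train_x = torch.randn(100, 1, 10)
train_y = torch.randint(0, 2, (100, 1)).float()
test_x = torch.randn(20, 1, 10)

# Convention 1: separate train/test tensors.
out = model(train_x, train_y, test_x)

# Convention 2: one concatenated (x, y) tuple, with single_eval_pos
# marking where the training context ends and evaluation begins.
x = torch.cat([train_x, test_x], dim=0)
out = model((x, train_y), single_eval_pos=100)

print(out.keys())  # a dict[str, Tensor] per the signature above
```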

Parameters:

  • train_x (torch.Tensor | None, required): The input data for the training set.
  • train_y (torch.Tensor | None, required): The target data for the training set.
  • test_x (torch.Tensor | None, required): The input data for the test set.
  • x (torch.Tensor, required): The input data.
  • y (torch.Tensor | None, required): The target data.
  • style (torch.Tensor | None, required): The style vector.
  • single_eval_pos (int, required): The position to evaluate at.
  • only_return_standard_out (bool, required): Whether to only return the standard output.
  • data_dags (Any, required): The data DAGs for each example.
  • categorical_inds (list[int], required): The indices of categorical features.
  • freeze_kv (bool, required): Whether to freeze the key and value weights.

Returns:

  • dict[str, Tensor]: The output of the model, which can be a tensor or a dictionary of tensors.

reset_save_peak_mem_factor

reset_save_peak_mem_factor(
    factor: int | None = None,
) -> None

Sets the save_peak_mem_factor for all layers.

This factor controls how much memory is saved during the forward pass in inference mode.

Setting this factor > 1 will cause the model to save more memory during the forward pass in inference mode.

A value of 8 works well for a 4x larger width in the fully-connected layers, and yields a situation where around 2 * num_features * num_items * emsize * 2 bytes of memory are needed for a forward pass (using mixed precision).

WARNING: It should only be used with post-norm.

Parameters:

  • factor (int | None, default None): The save_peak_mem_factor to set. Recommended value is 8.
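
A usage sketch (assuming `model` and the input tensors from the forward example above, with post-norm layers per the warning; whether passing None clears the factor is an assumption based on the int | None signature):

```python
import torch

# Reduce peak memory during inference; 8 is the recommended value.
model.reset_save_peak_mem_factor(8)

with torch.inference_mode():
    out = model(train_x, train_y, test_x)

model.reset_save_peak_mem_factor(None)  # presumably restores the default behavior
```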

SerializableGenerator

Bases: Generator

A serializable version of torch.Generator that can be saved and pickled.
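
The page doesn't show the mechanism, but one way to get the same effect is to pickle the generator's RNG state tensor rather than the C-backed object itself (a sketch using a plain wrapper; the actual class subclasses torch.Generator directly):

```python
import pickle
import torch

class PicklableRNG:
    """Hypothetical wrapper sketch (not the library's implementation):
    torch.Generator itself is not picklable, so serialize only its
    RNG state tensor and rebuild the generator on load."""

    def __init__(self, device: str = "cpu"):
        self.device = device
        self.generator = torch.Generator(device=device)

    def __getstate__(self):
        # get_state() returns a ByteTensor capturing the full RNG state.
        return {"device": self.device, "state": self.generator.get_state()}

    def __setstate__(self, state):
        self.device = state["device"]
        self.generator = torch.Generator(device=self.device)
        self.generator.set_state(state["state"])

# Round-trips through pickle while preserving the RNG stream.
rng = PicklableRNG()
restored = pickle.loads(pickle.dumps(rng))
```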