transformer ¶
LayerStack ¶
Bases: Module
Similar to nn.Sequential, but with support for passing keyword arguments to layers and for stacking the same layer multiple times.
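A minimal sketch of the idea (the class name matches, but the constructor signature and `AddOffset` toy layer below are illustrative assumptions, not the actual implementation):

```python
class LayerStack:
    """Sketch of a Sequential-like container that forwards keyword
    arguments to every layer and can stack one layer multiple times."""

    def __init__(self, layer_factory, num_layers):
        # build num_layers independent copies from a factory
        self.layers = [layer_factory() for _ in range(num_layers)]

    def __call__(self, x, **kwargs):
        # unlike nn.Sequential, **kwargs reach every layer
        for layer in self.layers:
            x = layer(x, **kwargs)
        return x


# toy layer: adds a keyword-controlled offset
class AddOffset:
    def __call__(self, x, offset=0):
        return x + offset


stack = LayerStack(AddOffset, num_layers=3)
print(stack(1, offset=2))  # → 7
```

The kwargs-forwarding is the key difference from `nn.Sequential`, which only threads a single positional input through its children.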
PerFeatureTransformer ¶
Bases: Module
A Transformer model that processes one token per feature and sample.
This model extends the standard Transformer architecture to operate on a per-feature basis. It allows for processing each feature separately while still leveraging the power of self-attention.
The model consists of an encoder, decoder, and optional components such as a feature positional embedding and a separate decoder for each feature.
forward ¶
Performs a forward pass through the model.
This method supports multiple calling conventions:
- `model((x, y), **kwargs)`
- `model(train_x, train_y, test_x, **kwargs)`
- `model((style, x, y), **kwargs)`
Parameters:

Name | Type | Description | Default
---|---|---|---
`train_x` | `torch.Tensor \| None` | The input data for the training set. | *required*
`train_y` | `torch.Tensor \| None` | The target data for the training set. | *required*
`test_x` | `torch.Tensor \| None` | The input data for the test set. | *required*
`x` | `torch.Tensor` | The input data. | *required*
`y` | `torch.Tensor \| None` | The target data. | *required*
`style` | `torch.Tensor \| None` | The style vector. | *required*
`single_eval_pos` | `int` | The position to evaluate at. | *required*
`only_return_standard_out` | `bool` | Whether to only return the standard output. | *required*
`data_dags` | `Any` | The data DAGs for each example. | *required*
`categorical_inds` | `list[int]` | The indices of categorical features. | *required*
`freeze_kv` | `bool` | Whether to freeze the key and value weights. | *required*
Returns:

Type | Description
---|---
`dict[str, Tensor]` | The output of the model, which can be a tensor or a dictionary of tensors.
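The three calling conventions can be sketched as a small dispatch helper (hypothetical: `normalize_args` is not part of the API, and plain lists stand in for tensors; it only illustrates how the argument forms map onto a canonical `(style, x, y, single_eval_pos)` shape):

```python
def normalize_args(*args):
    """Hypothetical helper mirroring forward()'s calling conventions:
    model((x, y)), model(train_x, train_y, test_x), model((style, x, y)).
    Returns a canonical (style, x, y, single_eval_pos) tuple."""
    if len(args) == 1:  # model((x, y)) or model((style, x, y))
        tup = args[0]
        if len(tup) == 2:
            x, y = tup
            return None, x, y, len(y)
        style, x, y = tup
        return style, x, y, len(y)
    # model(train_x, train_y, test_x): train and test inputs are
    # concatenated; targets exist only for the training portion
    train_x, train_y, test_x = args
    return None, train_x + test_x, train_y, len(train_x)


print(normalize_args([1, 2], [0, 1], [3]))  # → (None, [1, 2, 3], [0, 1], 2)
```

The last convention shows why `single_eval_pos` matters: everything before that position is treated as training context, everything after it is evaluated.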
reset_save_peak_mem_factor ¶
Sets the save_peak_mem_factor for all layers.
This factor controls how much memory is saved during the forward pass in inference mode.
Setting this factor > 1 causes the model to save more memory during the forward pass in inference mode. A value of 8 works well for a 4x larger width in the fully-connected layers, and yields a situation where a forward pass (using mixed precision) needs around 2 * num_features * num_items * emsize * 2 bytes of memory.
WARNING: It should only be used with post-norm.
Parameters:

Name | Type | Description | Default
---|---|---|---
`factor` | `int \| None` | The save_peak_mem_factor to set. Recommended value is 8. | `None`
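The memory estimate from the docstring can be checked with a quick back-of-the-envelope calculation (the helper name is hypothetical; the final factor of 2 is bytes per element under mixed precision, i.e. fp16):

```python
def approx_peak_mem_bytes(num_features: int, num_items: int, emsize: int) -> int:
    # ~ 2 * num_features * num_items * emsize * 2 bytes
    # (per the docstring, for a mixed-precision forward pass)
    return 2 * num_features * num_items * emsize * 2


# e.g. 100 features, 1000 rows, embedding size 512:
mb = approx_peak_mem_bytes(100, 1000, 512) / 1e6
print(f"{mb:.0f} MB")  # → 205 MB
```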
SerializableGenerator ¶
Bases: Generator
A serializable version of torch.Generator that can be saved and pickled.
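The motivation can be illustrated with a stdlib stand-in (assumption: `SerializableRNG` below uses `random.Random` in place of `torch.Generator`, and the real class presumably persists the generator state in a comparable way; torch.Generator objects are not picklable out of the box):

```python
import pickle
import random


class SerializableRNG:
    """Stdlib sketch of the SerializableGenerator idea: persist the
    generator state through pickling via __getstate__/__setstate__."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    def __getstate__(self):
        # capture the full internal state, which is plain picklable data
        return {"state": self._rng.getstate()}

    def __setstate__(self, state):
        # rebuild the generator and restore the captured state
        self._rng = random.Random()
        self._rng.setstate(state["state"])

    def random(self):
        return self._rng.random()


g = SerializableRNG(seed=42)
g.random()  # advance the state
clone = pickle.loads(pickle.dumps(g))
# clone now continues the random stream from exactly where g left off
```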