multi_head_attention

MultiHeadAttention

Bases: Module

forward

forward(
    x: Tensor,
    x_kv: Tensor | None = None,
    *,
    cache_kv: bool = False,
    add_input: bool = False,
    allow_inplace: bool = False,
    save_peak_mem_factor: int | None = None,
    reuse_first_head_kv: bool = False,
    only_cache_first_head_kv: bool = False,
    use_cached_kv: bool = False,
    use_second_set_of_queries: bool = False
)

'x' is the current hidden state and has a shape of [batch, ..., seq_len, input_size]. If keys and values are present in the cache and 'freeze_kv' is not set, they are taken from the cache and 'x_kv' must be None. Otherwise, if 'x_kv' is not None, keys and values are obtained by applying the respective linear transformations to 'x_kv' (cross-attention). Otherwise, keys and values are obtained by applying the respective linear transformations to 'x' itself (self-attention).
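The key/value selection described above can be illustrated with a short, self-contained sketch. This is not the real MultiHeadAttention implementation: the stand-in class, its constructor arguments (input_size, nhead), and the use of torch.nn.MultiheadAttention internally are assumptions made purely for illustration, and it only covers 3-D inputs of shape [batch, seq_len, input_size].

```python
import torch
from torch import Tensor, nn


class AttentionSketch(nn.Module):
    """Illustrative stand-in, NOT the real MultiHeadAttention.

    It only mirrors the documented dispatch between cached keys/values,
    a separate key/value input 'x_kv', and plain self-attention on 'x'.
    """

    def __init__(self, input_size: int, nhead: int) -> None:
        super().__init__()
        self.w_q = nn.Linear(input_size, input_size)
        self.w_k = nn.Linear(input_size, input_size)
        self.w_v = nn.Linear(input_size, input_size)
        self.attn = nn.MultiheadAttention(input_size, nhead, batch_first=True)
        self.cached_k: Tensor | None = None
        self.cached_v: Tensor | None = None

    def forward(
        self,
        x: Tensor,
        x_kv: Tensor | None = None,
        *,
        cache_kv: bool = False,
        use_cached_kv: bool = False,
        add_input: bool = False,
    ) -> Tensor:
        q = self.w_q(x)
        if use_cached_kv and self.cached_k is not None:
            # Cached path: keys/values come from the cache; 'x_kv' must be None.
            assert x_kv is None
            k, v = self.cached_k, self.cached_v
        elif x_kv is not None:
            # Cross-attention: project the separate key/value input.
            k, v = self.w_k(x_kv), self.w_v(x_kv)
        else:
            # Self-attention: project 'x' itself.
            k, v = self.w_k(x), self.w_v(x)
        if cache_kv:
            # Store the projected keys/values for reuse in a later call.
            self.cached_k, self.cached_v = k, v
        out, _ = self.attn(q, k, v)
        # 'add_input' adds a residual connection around the attention output.
        return x + out if add_input else out


mha = AttentionSketch(input_size=64, nhead=4)
x = torch.randn(2, 10, 64)        # [batch, seq_len, input_size]
ctx = torch.randn(2, 20, 64)      # key/value source for cross-attention
_ = mha(x, ctx, cache_kv=True)    # compute and cache keys/values from ctx
_ = mha(x, use_cached_kv=True)    # reuse cached keys/values; x_kv stays None
```

The remaining flags in the real signature (save_peak_mem_factor, reuse_first_head_kv, only_cache_first_head_kv, use_second_set_of_queries, allow_inplace) control memory and caching optimizations and are not modeled by this sketch.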