multi_head_attention
MultiHeadAttention
Bases: Module
forward
forward(
    x: Tensor,
    x_kv: Tensor | None = None,
    *,
    cache_kv: bool = False,
    add_input: bool = False,
    allow_inplace: bool = False,
    save_peak_mem_factor: int | None = None,
    reuse_first_head_kv: bool = False,
    only_cache_first_head_kv: bool = False,
    use_cached_kv: bool = False,
    use_second_set_of_queries: bool = False,
)
'x' is the current hidden state and has shape [batch, ..., seq_len, input_size]. If keys and values are present in the cache and 'use_cached_kv' is set, they are taken from the cache, and 'x_kv' must be None. Otherwise, if 'x_kv' is not None, keys and values are obtained by applying the respective linear transformations to 'x_kv' (cross-attention). Otherwise, keys and values are obtained by applying the respective linear transformations to 'x' itself (self-attention).
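The call patterns implied by this signature and description look roughly as follows. This is a minimal sketch: the import path and constructor arguments are assumptions, since neither appears in this section; only the forward() arguments come from the signature above.

```python
import torch

# Assumption: the import path and constructor arguments below are placeholders
# for illustration; they are not documented in this section.
# from <package>.multi_head_attention import MultiHeadAttention
attn = MultiHeadAttention(input_size=512, output_size=512, nhead=8)

x = torch.randn(2, 128, 512)      # [batch, seq_len, input_size]
x_new = torch.randn(2, 32, 512)   # a second sequence of queries

# Self-attention: no cache is used and x_kv is None, so keys and values
# are computed from x itself.
out = attn(x)

# Cross-attention: keys and values are computed from x_kv.
out = attn(x_new, x_kv=x)

# Cache the keys and values computed from x once ...
_ = attn(x, cache_kv=True)
# ... then reuse them for new queries; x_kv must be None in this call.
out = attn(x_new, use_cached_kv=True)
```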