nn.MultiScaleDeformableAttention

class horizon_plugin_pytorch.nn.MultiScaleDeformableAttention (embed_dims: int = 256, num_heads: int = 8, num_levels: int = 4, num_points: int = 4, im2col_step: int = 64, dropout: float = 0.1, batch_first: bool = False, value_proj_ratio: float = 1.0, split_weight_mul: bool = False, split_batch: bool = False)

Parameters:

embed_dims (int) – The embedding dimension of Attention. Default: 256.

num_heads (int) – Parallel attention heads. Default: 8.

num_levels (int) – The number of feature maps used in Attention. Default: 4.

num_points (int) – The number of sampling points for each query in each head. Default: 4.

im2col_step (int) – The step used in image_to_column. Default: 64.

dropout (float) – Dropout probability of the Dropout layer on inp_identity. Default: 0.1.

batch_first (bool) – Whether Key, Query and Value have shape (batch, n, embed_dim) (True) or (n, batch, embed_dim) (False). Default: False.

value_proj_ratio (float) – The expansion ratio of value_proj. Default: 1.0.

split_weight_mul (bool) – Whether to split the attention weight multiplication across the outputs of each level. Enabling this can reduce memory usage in QAT training. Default: False.

split_batch (bool) – Whether to compute one batch element at a time. Enabling this can reduce memory usage in QAT training. Default: False.
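
The constructor arguments imply a few derived sizes. The sketch below is plain Python and assumes the standard Deformable DETR layout: each head samples num_levels * num_points locations per query, each with a 2-D offset, and the projected value channels are split evenly across heads. This page does not confirm the plugin's internals, and `derived_sizes` is a hypothetical helper, not part of horizon_plugin_pytorch.

```python
def derived_sizes(embed_dims=256, num_heads=8, num_levels=4, num_points=4,
                  value_proj_ratio=1.0):
    """Hypothetical helper: sizes implied by the constructor arguments.

    Assumes the standard Deformable DETR layout; not taken from the plugin source.
    """
    # Assumption: the embedding must split evenly across heads.
    if embed_dims % num_heads != 0:
        raise ValueError("embed_dims must be divisible by num_heads")
    # Each query samples num_levels * num_points locations in every head.
    samples_per_query = num_heads * num_levels * num_points
    # Assumption: value_proj maps embed_dims to embed_dims * value_proj_ratio.
    value_channels = int(embed_dims * value_proj_ratio)
    return {
        "samples_per_query": samples_per_query,
        # Two offset coordinates per sampled location.
        "offset_channels": samples_per_query * 2,
        "value_channels": value_channels,
        "value_head_dim": value_channels // num_heads,
    }


sizes = derived_sizes()
print(sizes)
```

With the default arguments, every query reads 128 sampled locations in total, which is why increasing num_levels or num_points scales memory use linearly.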

forward (query: Tensor, key: Tensor | None = None, value: Tensor | None = None, identity: Tensor | None = None, query_pos: Tensor | None = None, key_padding_mask: Tensor | None = None, reference_points: Tensor | None = None, spatial_shapes: Tensor | None = None)

Parameters:

query (Tensor) – Query of Transformer with shape (num_query, bs, embed_dims).

key (Optional[Tensor]) – The key tensor with shape (num_key, bs, embed_dims).

value (Optional[Tensor]) – The value tensor with shape (num_key, bs, embed_dims).

identity (Optional[Tensor]) – The tensor used for addition, with the same shape as query. Default None. If None, query will be used.

query_pos (Optional[Tensor]) – The positional encoding for query. Default: None.

key_padding_mask (Optional[Tensor]) – ByteTensor mask over the key positions, with shape (bs, num_key).

reference_points (Optional[Tensor]) – The normalized reference points, with shape (bs, num_query, num_levels, 2). All elements are in the range [0, 1], where (0, 0) is the top-left corner and (1, 1) is the bottom-right corner, including the padding area. Alternatively, the shape may be (bs, num_query, num_levels, 4), where the two additional values (w, h) turn each point into a reference box.

spatial_shapes (Optional[Tensor]) – Spatial shapes of the feature maps at each level. Int tensor with shape (num_levels, 2), where the last dimension is (h, w).

Returns: The output tensor, with the same shape as query.

Return type: Tensor
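
The shape rules for forward can be checked with a little bookkeeping before any tensors are built. The sketch below is plain Python and rests on two assumptions not stated on this page: that value is the concatenation of all levels flattened in order, so num_key equals the sum of h * w over spatial_shapes, and that the last dimension of reference_points is ordered (x, y). `forward_shapes` and `normalize_point` are hypothetical helpers, not part of horizon_plugin_pytorch.

```python
def forward_shapes(num_query, bs, embed_dims, spatial_shapes, batch_first=False):
    """Hypothetical helper: expected tensor shapes for forward()."""
    num_levels = len(spatial_shapes)
    # Assumption: value holds every level flattened and concatenated.
    num_key = sum(h * w for h, w in spatial_shapes)
    # Index in the flattened value sequence where each level begins.
    level_start = [0]
    for h, w in spatial_shapes[:-1]:
        level_start.append(level_start[-1] + h * w)

    def seq(n):
        # batch_first selects (batch, n, embed_dim) over (n, batch, embed_dim).
        return (bs, n, embed_dims) if batch_first else (n, bs, embed_dims)

    return {
        "query": seq(num_query),
        "value": seq(num_key),
        "key_padding_mask": (bs, num_key),
        "reference_points": (bs, num_query, num_levels, 2),
        "spatial_shapes": (num_levels, 2),
        "output": seq(num_query),  # same shape as query
        "level_start": level_start,
    }


def normalize_point(x, y, width, height):
    """Map a pixel coordinate to [0, 1]; (0, 0) is top-left, (1, 1) bottom-right.

    Assumption: width and height include the padding area, and the order is (x, y).
    """
    return (x / width, y / height)


spatial_shapes = [(64, 64), (32, 32), (16, 16), (8, 8)]  # (h, w) per level
shapes = forward_shapes(num_query=300, bs=2, embed_dims=256,
                        spatial_shapes=spatial_shapes)
point = normalize_point(320, 120, width=640, height=480)
print(shapes)
print(point)
```

Here the four levels give num_key = 5440, so with batch_first=False the value tensor would have shape (5440, 2, 256) and the output matches the query shape (300, 2, 256).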