
src.jormungandr.embedder

Classes:

DetrLearnedPositionEmbedding
DetrSinePositionEmbedding

DetrLearnedPositionEmbedding(embedding_dim=256)

Bases: Module, Embedder

This module learns positional embeddings up to a fixed maximum spatial size (a 50×50 grid, as set by the two embedding tables below).

Source code in src/jormungandr/embedder.py
def __init__(self, embedding_dim=256):
    super().__init__()
    self.row_embeddings = nn.Embedding(50, embedding_dim)
    self.column_embeddings = nn.Embedding(50, embedding_dim)
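The `forward` pass of this class is not shown here, but the scheme is simple: each spatial position indexes its row into one table and its column into the other, and the two vectors are combined per position. A minimal pure-Python sketch of the lookup (the concatenation of column and row vectors follows the DETR convention and is an assumption, as are the helper names `make_table` and `learned_position_embedding`):

```python
import random

def make_table(num_positions, dim, seed):
    # Hypothetical stand-in for nn.Embedding: a fixed lookup table of vectors.
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)]
            for _ in range(num_positions)]

def learned_position_embedding(height, width, dim=4, max_size=50):
    # Two tables, one indexed by row and one by column, mirroring the
    # row_embeddings / column_embeddings tables above (max_size = 50).
    rows = make_table(max_size, dim, seed=0)
    cols = make_table(max_size, dim, seed=1)
    # One vector per spatial position: the column vector concatenated with
    # the row vector, so every position gets a dim * 2 embedding
    # (assumed combination; the real forward may differ).
    return [[cols[x] + rows[y] for x in range(width)] for y in range(height)]

emb = learned_position_embedding(3, 5, dim=4)
print(len(emb), len(emb[0]), len(emb[0][0]))  # 3 5 8
```

Because the tables are indexed independently, all positions in a row share the row half of their embedding, and all positions in a column share the column half.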

DetrSinePositionEmbedding

DetrSinePositionEmbedding(num_position_features: int = 128, temperature: int = 10000, normalize: bool = True, scale: float | None = None)

Bases: Module, Embedder

This is a more standard version of the position embedding, very similar to the one used in the "Attention Is All You Need" paper, generalized to work on 2D feature maps.

Methods:

forward – Compute the position embedding for a given feature-map shape.

Source code in src/jormungandr/embedder.py
def __init__(
    self,
    num_position_features: int = 128,
    temperature: int = 10000,
    normalize: bool = True,
    scale: float | None = None,
):
    super().__init__()
    if scale is not None and normalize is False:
        raise ValueError("normalize should be True if scale is passed")
    self.num_position_features = num_position_features
    self.temperature = temperature
    self.normalize = normalize
    self.scale = 2 * math.pi if scale is None else scale
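The constructor's only validation is the coupling between `scale` and `normalize`: an explicit scale is meaningless unless the coordinates are first normalized to [0, 1]. A pure-Python sketch of just that logic (the helper name `resolve_scale` is mine, not part of the module):

```python
import math

def resolve_scale(normalize=True, scale=None):
    # Mirrors the __init__ validation above: a custom scale only makes sense
    # when the cumulative coordinates are normalized first.
    if scale is not None and normalize is False:
        raise ValueError("normalize should be True if scale is passed")
    # Default scale is a full period, 2 * pi.
    return 2 * math.pi if scale is None else scale

print(resolve_scale())  # 6.283185307179586, i.e. 2 * pi
```

With the default scale of 2π, normalized coordinates sweep one full period of the sine/cosine waves across the feature map.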

forward

forward(shape: Size, device: device | str, dtype: dtype, mask: Tensor | None = None) -> torch.Tensor

Parameters:

  • shape

    (Size) –

    The shape of the feature maps for which to compute the position embedding, expected to be (batch_size, channels, height, width)

  • device

    (device | str) –

    The device on which to create the position embedding

  • dtype

    (dtype) –

    The dtype of the position embedding

  • mask

    (Tensor | None, default: None ) –

    An optional mask tensor of shape (batch_size, height, width) where True values indicate masked positions. If None, no positions are masked.

Returns: A position embedding tensor of shape (batch_size, sequence_length, hidden_size) where sequence_length is height * width and hidden_size is num_position_features * 2 (for sine and cosine components)

Source code in src/jormungandr/embedder.py
def forward(
    self,
    shape: torch.Size,
    device: torch.device | str,
    dtype: torch.dtype,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Args:
        shape: The shape of the feature maps for which to compute the position embedding, expected to be (batch_size, channels, height, width)
        device: The device on which to create the position embedding
        dtype: The dtype of the position embedding
        mask: An optional mask tensor of shape (batch_size, height, width) where True values indicate masked positions. If None, no positions are masked.
    Returns:
        A position embedding tensor of shape (batch_size, sequence_length, hidden_size) where sequence_length is height * width and hidden_size is num_position_features * 2 (for sine and cosine components)
    """
    if mask is None:
        mask = torch.zeros(
            (shape[0], shape[2], shape[3]), device=device, dtype=torch.bool
        )
    # Positions are counted over unmasked pixels: cumulative sums over the
    # inverted mask give 1-based (y, x) coordinates for each valid position.
    # (Cumulative sums over the mask itself would yield all-zero coordinates
    # for the all-False default, collapsing every position embedding.)
    not_mask = ~mask
    y_embed = not_mask.cumsum(1, dtype=dtype)
    x_embed = not_mask.cumsum(2, dtype=dtype)
    if self.normalize:
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

    dim_t = torch.arange(
        self.num_position_features, dtype=torch.int64, device=device
    ).to(dtype)
    dim_t = self.temperature ** (
        2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features
    )

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
    ).flatten(3)
    pos_y = torch.stack(
        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
    ).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
    # expected by the encoder
    pos = pos.flatten(2).permute(0, 2, 1)
    return pos
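To see what the body of `forward` computes for a single position, here is a pure-Python sketch (math only, no torch) of the per-coordinate feature vector. The function name `sine_position_features` and the tiny `num_position_features=4` are illustrative choices, not part of the module:

```python
import math

def sine_position_features(coord, num_position_features=4, temperature=10000.0):
    # Frequency terms: temperature ** (2 * floor(i / 2) / num_position_features),
    # matching the dim_t computation above (pairs of indices share a frequency).
    dim_t = [temperature ** (2 * (i // 2) / num_position_features)
             for i in range(num_position_features)]
    scaled = [coord / d for d in dim_t]
    # Interleave sin of even indices with cos of odd indices, as the
    # stack(...).flatten(3) calls above do.
    out = []
    for i in range(0, num_position_features, 2):
        out.append(math.sin(scaled[i]))
        out.append(math.cos(scaled[i + 1]))
    return out

# The full embedding for a pixel is its y features followed by its x features,
# so hidden_size = 2 * num_position_features.
y_feat = sine_position_features(1.0)
x_feat = sine_position_features(2.0)
pos = y_feat + x_feat
print(len(pos))  # 8
```

In the module itself, `coord` comes from the cumulative sums over unmasked pixels (optionally normalized to [0, scale]), and this per-position vector is computed for every pixel before the final flatten/permute to (batch_size, sequence_length, hidden_size).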