File size: 4,545 Bytes
05d640e
8ef2cad
05d640e
 
235555c
 
 
 
 
05d640e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ef2cad
05d640e
 
 
 
 
 
 
 
 
 
235555c
05d640e
 
8ef2cad
05d640e
 
 
 
 
 
 
 
 
 
 
 
 
8ef2cad
05d640e
8ef2cad
 
05d640e
 
8ef2cad
05d640e
 
 
 
235555c
05d640e
 
8ef2cad
05d640e
8ef2cad
 
 
 
 
 
 
 
 
05d640e
 
 
 
 
8ef2cad
 
05d640e
 
235555c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .layers import mlp

SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]


def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    Maps inputs through a bank of sinusoids to produce Fourier features.

    Scaling x by 2*pi, projecting it against the frequency matrix w, and taking
    cos/sin of the result lifts the input into a higher-dimensional space of
    explicit frequencies. This counteracts spectral bias — neural networks pick
    up low-frequency structure more readily than high-frequency detail — so
    the downstream model can fit fine-grained patterns.

    Args:
        x: Input tensor to transform.
        w: Matrix of frequencies for the Fourier features transformation.

    Returns:
        Tensor of the cosine features concatenated with the sine features
        along the last dimension.
    """
    scaled = 2 * math.pi * x
    angles = torch.matmul(scaled, w)
    return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)


def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Encodes a single float coordinate value (x or y) into hidden states for
    input to the text model.

    Args:
        coord: Tensor holding a single float coordinate value.
        w: Module providing `coord_features` (the Fourier frequency matrix)
            and `coord_encoder` (the projection applied to the features).

    Returns:
        Encoded hidden states tensor for input to the text model.
    """
    features = fourier_features(coord, w.coord_features)
    return w.coord_encoder(features)


def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Projects the text model's final hidden state down to a single logit
    representing either an x or y coordinate prediction.

    Args:
        hidden_state: The final hidden state tensor from the text model.
        w: Module providing `coord_decoder`, the MLP weights used for the
            projection.

    Returns:
        A single logit representing the predicted coordinate value (x or y).
    """
    decoder_weights = w.coord_decoder
    return mlp(hidden_state, decoder_weights)


def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Encodes a (width, height) pair into hidden states for input to the
    text model.

    Args:
        size: Tensor with two floats for width and height.
        w: Module providing `size_features` (the Fourier frequency matrix)
            and `size_encoder` (the projection applied to the features).

    Returns:
        Encoded hidden states tensor for input to the text model.
    """
    features = fourier_features(size, w.size_features)
    return w.size_encoder(features)


def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Maps the text model's final hidden state to logits over 1024 log-scale
    bins for width and for height.

    The bins are distributed according to the formula:
    bin = (log2(size) + 10.0) / 10.0 * 1023.0
    where size values are clamped to be at least 1/1024.

    To convert from bin back to size:
    size = 2^((bin / 1023.0) * 10.0 - 10.0)

    Args:
        hidden_state: The final hidden state tensor from the text model.
        w: Module providing `size_decoder`, the MLP weights used for the
            projection.

    Returns:
        A tensor of logits over 1024 bins, shaped (2, 1024), where row 0 is
        width and row 1 is height.
    """
    flat_logits = mlp(hidden_state, w.size_decoder)
    return flat_logits.view(2, -1)


def encode_spatial_refs(
    spatial_refs: SpatialRefs, w: nn.Module
) -> Dict[str, Optional[torch.Tensor]]:
    """
    Encodes a list of spatial references (points or regions) into hidden
    states for input to the text model.

    Boxes are decomposed into a center point plus a (width, height) size;
    points contribute only their coordinates.

    Args:
        spatial_refs: List of spatial references (points or boxes)
            - Points are represented as normalized (x, y) tuples
            - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
        w: Module providing the coordinate/size Fourier feature matrices and
            encoder weights used by `encode_coordinate` / `encode_size`.

    Returns:
        Dict with:
            "coords": encoded hidden states for every x/y value, two rows per
                reference (x first, then y).
            "sizes": encoded hidden states for each box's (width, height), or
                None when no boxes were given.
    """
    # NOTE: the original signature claimed `-> torch.Tensor`, but the function
    # returns a dict; the annotation now matches the actual return value.
    coords: List[float] = []
    sizes: List[List[float]] = []
    for ref in spatial_refs:
        if len(ref) == 2:
            # Point: contributes its x and y directly.
            coords.extend(ref)
        else:
            # Box: encode as center coordinates plus width/height size.
            x_min, y_min, x_max, y_max = ref
            coords.append((x_min + x_max) / 2)
            coords.append((y_min + y_max) / 2)
            sizes.append([x_max - x_min, y_max - y_min])

    # One coordinate value per row, matching encode_coordinate's expectation.
    coords_t = torch.tensor(
        coords, device=w.coord_features.device, dtype=w.coord_features.dtype
    ).view(-1, 1)
    encoded_coords = encode_coordinate(coords_t, w)

    encoded_sizes: Optional[torch.Tensor] = None
    if sizes:
        sizes_t = torch.tensor(
            sizes, device=w.size_features.device, dtype=w.size_features.dtype
        )
        encoded_sizes = encode_size(sizes_t, w)

    return {"coords": encoded_coords, "sizes": encoded_sizes}