License: MIT

Galileo

Learning Global and Local Features in Pretrained Remote Sensing Models

[Figure: Galileo diagram]

Galileo is a family of pretrained remote sensing models. The models are pretrained on a diverse set of remote sensing inputs and perform well on a range of benchmark tasks. For more information, please see our paper.

Using Galileo

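Galileo can be loaded either from src or from single_file_galileo.py, which is provided for easy porting to other codebases. The snippet below loads the nano model both ways and checks that the two sets of parameters are identical: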

import torch

from single_file_galileo import Encoder as SingleFileEncoder
from src.galileo import Encoder

# DATA_FOLDER is assumed to be a pathlib.Path pointing at the directory
# that contains the downloaded "models" folder.
src_model = Encoder.load_from_folder(DATA_FOLDER / "models/nano")
sf_model = SingleFileEncoder.load_from_folder(
    DATA_FOLDER / "models/nano", device=torch.device("cpu")
)

# Both loaders should produce identical weights.
for model_p, sf_model_p in zip(src_model.parameters(), sf_model.parameters()):
    assert torch.equal(model_p, sf_model_p)

The inputs to Galileo are described by the MaskedOutput named tuple:

class MaskedOutput(NamedTuple):
    """
    A mask can take 3 values:
    0: seen by the encoder (i.e. makes the key and value tokens in the decoder)
    1: not seen by the encoder, and ignored by the decoder
    2: not seen by the encoder, and processed by the decoder (the decoder's query values)
    """

    space_time_x: torch.Tensor  # [B, H, W, T, len(SPACE_TIME_BANDS)]
    space_x: torch.Tensor  # [B, H, W, len(SPACE_BANDS)]
    time_x: torch.Tensor  # [B, T, len(TIME_BANDS)]
    static_x: torch.Tensor  # [B, len(STATIC_BANDS)]
    space_time_mask: torch.Tensor  # [B, H, W, T, len(SPACE_TIME_BANDS_GROUPS_IDX)]
    space_mask: torch.Tensor  # [B, H, W, len(SPACE_BAND_GROUPS_IDX)]
    time_mask: torch.Tensor  # [B, T, len(TIME_BAND_GROUPS_IDX)]
    static_mask: torch.Tensor  # [B, len(STATIC_BAND_GROUPS_IDX)]
    months: torch.Tensor  # [B, T]

Each of these bands is described in single_file_galileo.py.
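The three mask values in the docstring above can be illustrated with a small sketch. It uses plain Python lists instead of tensors so that it runs without torch. The helper `split_tokens_by_mask` and the constant names are hypothetical and are not part of the Galileo codebase; they only show how each value assigns a token to the encoder, to the decoder's queries, or to neither.

```python
from typing import Dict, List, Sequence

# Mask values as documented in the MaskedOutput docstring.
# These constant names are illustrative, not taken from the Galileo source.
SEEN_BY_ENCODER = 0  # becomes a key/value token in the decoder
IGNORED = 1          # seen by neither the encoder nor the decoder
DECODER_QUERY = 2    # hidden from the encoder, used as a decoder query

ROLE_NAMES = {
    SEEN_BY_ENCODER: "encoder",
    IGNORED: "ignored",
    DECODER_QUERY: "decoder_queries",
}


def split_tokens_by_mask(mask: Sequence[int]) -> Dict[str, List[int]]:
    """Group flattened token indices by the role their mask value assigns."""
    roles: Dict[str, List[int]] = {name: [] for name in ROLE_NAMES.values()}
    for index, value in enumerate(mask):
        if value not in ROLE_NAMES:
            raise ValueError(f"Unexpected mask value {value} at index {index}")
        roles[ROLE_NAMES[value]].append(index)
    return roles


# A flattened mask over six tokens.
example_mask = [0, 2, 1, 0, 2, 2]
roles = split_tokens_by_mask(example_mask)
print(roles)
```

With real tensors, the same grouping would typically be done with boolean indexing (for example `mask == 0`) rather than a Python loop.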

Alternatively, the utility function construct_galileo_input transforms the bands into a MaskedOutput object. The transformation is for a single instance (i.e. it omits the B dimension above). The function can optionally normalize the data against the Galileo pre-training statistics.

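import torch

from src.data.utils import S2_BANDS, construct_galileo_input

t, h, w = 2, 4, 4
# Dummy Sentinel-2 time series with shape [T, H, W, len(S2_BANDS)]
s2 = torch.randn((t, h, w, len(S2_BANDS)))
# normalize=True is assumed here to enable normalization against the
# Galileo pre-training statistics.
masked_output = construct_galileo_input(s2=s2, normalize=True)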
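Because the utility returns a single instance without the B dimension, it can help to write down the shapes expected with and without batching. The sketch below derives them from the shape comments in MaskedOutput. The helper `expected_shapes` is hypothetical, and the band counts passed to it are placeholders rather than the real lengths of the band lists in single_file_galileo.py.

```python
from typing import Dict, Optional, Tuple


def expected_shapes(
    h: int,
    w: int,
    t: int,
    n_space_time: int,
    n_space: int,
    n_time: int,
    n_static: int,
    batch_size: Optional[int] = None,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the data fields, following the MaskedOutput shape comments.

    With batch_size=None the leading B dimension is omitted, matching the
    single-instance case described above.
    """
    prefix: Tuple[int, ...] = () if batch_size is None else (batch_size,)
    return {
        "space_time_x": prefix + (h, w, t, n_space_time),
        "space_x": prefix + (h, w, n_space),
        "time_x": prefix + (t, n_time),
        "static_x": prefix + (n_static,),
        "months": prefix + (t,),
    }


# Placeholder band counts, chosen only for illustration.
counts = dict(n_space_time=5, n_space=3, n_time=2, n_static=4)

single = expected_shapes(h=4, w=4, t=2, **counts)
batched = expected_shapes(h=4, w=4, t=2, batch_size=8, **counts)

print(single["space_time_x"])
print(batched["space_time_x"])
```

Note that the s2 input in the example above is time-first ([T, H, W, bands]), while the space_time_x field is documented as [H, W, T, bands], so the two orderings should not be confused when checking shapes.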