---
tags:
- astronomy
- multimodal
- classification
datasets:
- AstroMLCore/AstroM3Processed
- AstroMLCore/AstroM3Dataset
---
AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space
for classification and other downstream tasks. AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/AstroMLCore/AstroM3Processed),
which is the pre-processed version of [AstroM3Dataset](https://huggingface.co/datasets/AstroMLCore/AstroM3Dataset).
For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).
<p align="center">
<img src="astroclip-architecture.png" width="100%">
<br />
<span>
Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata.
Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads.
Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model.
The total loss, derived from all pairwise losses, guides the model’s trimodal learning.
</span>
</p>
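For intuition, the pairwise objective in Figure 1 can be sketched in a few lines of PyTorch. This is an illustrative reimplementation under our own naming (`pairwise_clip_loss`, `temperature`), not the exact training code; see the GitHub repo for that.

```python
import torch
import torch.nn.functional as F

def pairwise_clip_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric cross-entropy over the similarity matrix of two modalities.
    Illustrative sketch only; the actual AstroM3 implementation may differ."""
    emb_a = F.normalize(emb_a, dim=-1)
    emb_b = F.normalize(emb_b, dim=-1)
    logits = emb_a @ emb_b.T / temperature                        # (batch, batch) similarity matrix
    targets = torch.arange(emb_a.size(0), device=logits.device)  # matching pairs lie on the diagonal
    loss_ab = F.cross_entropy(logits, targets)                   # modality A -> modality B
    loss_ba = F.cross_entropy(logits.T, targets)                 # modality B -> modality A
    return (loss_ab + loss_ba) / 2

# The trimodal loss in Figure 1 is the sum over all three modality pairs:
# total_loss = pairwise(photo, spectra) + pairwise(spectra, meta) + pairwise(photo, meta)
```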
To use AstroM³ for inference, clone the AstroM3 code from our [GitHub repo](https://github.com/MeriDK/AstroM3):
```sh
git clone https://github.com/MeriDK/AstroM3.git
cd AstroM3
```
Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:
```sh
uv venv venv --python 3.10.14
source venv/bin/activate
uv pip install -r requirements.txt
```
## A simple example to get started
1. Data Loading & Preprocessing
```python
from datasets import load_dataset
from src.data import process_photometry
# Load the test dataset
test_dataset = load_dataset('AstroMLCore/AstroM3Processed', name='full_42', split='test')
# Process photometry to have a fixed sequence length of 200 (center-cropped)
test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})
test_dataset = test_dataset.with_format('torch')
```
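Because the processed split is in torch format, it also works directly with a standard PyTorch `DataLoader` for batched inference (a usage sketch; the batch size here is arbitrary):

```python
from torch.utils.data import DataLoader

# Iterate over the processed test split in batches
loader = DataLoader(test_dataset, batch_size=64)
batch = next(iter(loader))
print(batch['photometry'].shape)  # sequence length is 200 after process_photometry
```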
2. Model Loading & Embedding Extraction
```python
import torch
from src.model import AstroM3
# Load the base AstroM3-CLIP model
model = AstroM3.from_pretrained('AstroMLCore/AstroM3-CLIP')
# Retrieve the first sample (batch size = 1)
sample = test_dataset[0:1]
photometry = sample['photometry']
photometry_mask = sample['photometry_mask']
spectra = sample['spectra']
metadata = sample['metadata']
# Example 1: Generate embeddings when all modalities are present
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
multimodal_emb = (p_emb + s_emb + m_emb) / 3
print('Multimodal Embedding (All Modalities):', multimodal_emb)
# Example 2: Generate embeddings when the spectra modality is missing
dummy_spectra = torch.zeros_like(spectra) # Dummy tensor for missing spectra
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)
multimodal_emb_missing = (p_emb + m_emb) / 2
print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
```
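Since all modalities live in one shared space, the resulting embeddings can be compared directly, e.g. with cosine similarity (a small illustrative check, not a library function):

```python
import torch.nn.functional as F

# How close is the spectra-missing embedding to the full multimodal one?
sim = F.cosine_similarity(multimodal_emb, multimodal_emb_missing)
print('Cosine similarity (all modalities vs. spectra missing):', sim.item())
```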
3. Classification Examples
```python
from src.model import AstroM3, Informer, GalSpecNet, MetaModel
# Photometry classification
photo_model = Informer.from_pretrained('AstroMLCore/AstroM3-CLIP-photo')
prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))
# Spectra classification
spectra_model = GalSpecNet.from_pretrained('AstroMLCore/AstroM3-CLIP-spectra')
prediction = spectra_model(spectra).argmax(dim=1).item()
print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))
# Metadata classification
meta_model = MetaModel.from_pretrained('AstroMLCore/AstroM3-CLIP-meta')
prediction = meta_model(metadata).argmax(dim=1).item()
print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))
# Multimodal classification
all_model = AstroM3.from_pretrained('AstroMLCore/AstroM3-CLIP-all')
prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
```
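To go beyond a single sample, the same calls extend to a full test-split evaluation. The loop below is a sketch that assumes the processed dataset exposes a `label` column (as used above) and reuses the batching shown earlier:

```python
import torch
from torch.utils.data import DataLoader

all_model.eval()
correct = total = 0
with torch.no_grad():
    for batch in DataLoader(test_dataset, batch_size=64):
        logits = all_model(batch['photometry'], batch['photometry_mask'],
                           batch['spectra'], batch['metadata'])
        correct += (logits.argmax(dim=1) == batch['label']).sum().item()
        total += batch['label'].size(0)
print(f'Multimodal test accuracy: {correct / total:.3f}')
```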
## The AstroM³ Family
| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP](https://huggingface.co/AstroMLCore/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
| [AstroM3-CLIP-meta](https://huggingface.co/AstroMLCore/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
| [AstroM3-CLIP-spectra](https://huggingface.co/AstroMLCore/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
| [AstroM3-CLIP-photo](https://huggingface.co/AstroMLCore/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
| [AstroM3-CLIP-all](https://huggingface.co/AstroMLCore/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |
## AstroM3-CLIP Variants
These variants of the base AstroM3-CLIP model are trained using different random seeds (42, 0, 66, 12, 123);
ensure that the dataset is loaded with the corresponding seed for consistency.
| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP-42](https://huggingface.co/AstroMLCore/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
| [AstroM3-CLIP-0](https://huggingface.co/AstroMLCore/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use dataset with seed 0). |
| [AstroM3-CLIP-66](https://huggingface.co/AstroMLCore/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use dataset with seed 66). |
| [AstroM3-CLIP-12](https://huggingface.co/AstroMLCore/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use dataset with seed 12). |
| [AstroM3-CLIP-123](https://huggingface.co/AstroMLCore/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use dataset with seed 123). |
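For example, to evaluate AstroM3-CLIP-0, load the dataset configuration matching seed 0 (assuming configuration names follow the `full_<seed>` pattern used above):

```python
from datasets import load_dataset
from src.model import AstroM3

# The dataset seed must match the model's pre-training seed
test_dataset = load_dataset('AstroMLCore/AstroM3Processed', name='full_0', split='test')
model = AstroM3.from_pretrained('AstroMLCore/AstroM3-CLIP-0')
```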
## Using your own data
Note that the data in the AstroM3Processed dataset is already pre-processed.
If you want to use the model with your own data, you must pre-process it in the same way:
1. **Spectra**: Each spectrum is interpolated to a fixed wavelength grid (3850–9000 Å), normalized using mean and MAD, and log-MAD is added as an auxiliary feature.
2. **Photometry**: Light curves are deduplicated, sorted by time, normalized using mean and MAD, time-scaled to [0, 1], and augmented with auxiliary features like log-MAD and time span.
3. **Metadata**: Scalar metadata is transformed via domain-specific functions (e.g., absolute magnitude, log, sin/cos), then normalized using dataset-level statistics.
For a detailed description, read the [paper](https://arxiv.org/abs/2411.08842).
To see exactly how we performed this preprocessing, refer to [`preprocess.py`](https://huggingface.co/datasets/AstroMLCore/AstroM3Dataset/blob/main/preprocess.py) in the AstroM3Dataset repo.
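As a rough illustration of step 2, the core of the photometry normalization might look like the sketch below. The helper `normalize_photometry` is hypothetical and simplified; `preprocess.py` above is the authoritative version.

```python
import numpy as np

def normalize_photometry(time, flux):
    """Simplified, hypothetical sketch of the photometry preprocessing steps."""
    # Deduplicate timestamps and sort by time in one step
    time, idx = np.unique(time, return_index=True)
    flux = flux[idx]
    # Normalize flux by mean and median absolute deviation (MAD)
    mad = np.median(np.abs(flux - np.median(flux)))
    flux = (flux - flux.mean()) / mad
    # Scale time to [0, 1]
    t_span = time.max() - time.min()
    time = (time - time.min()) / t_span
    # Auxiliary features: log-MAD and time span
    aux = np.array([np.log(mad), t_span])
    return time, flux, aux
```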
---
## Citation
🤗 If you find this model useful, please cite our paper 🤗
```bibtex
@article{rizhko2024astrom,
title={AstroM$^3$: A self-supervised multimodal model for astronomy},
author={Rizhko, Mariia and Bloom, Joshua S},
journal={arXiv preprint arXiv:2411.08842},
year={2024}
}
```