metadata
license: mit
datasets:
- universeTBD/arxiv-astro-abstracts-all
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
AstroLLaMA
Loading the model
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",
)
Generating text from a prompt
import torch
from transformers import pipeline

generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = "In this letter, we report the discovery of the highest redshift, " \
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, " \
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "

# For reproducibility
torch.manual_seed(42)

generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512
)
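The text-generation pipeline returns a list with one dict per generated sequence, and by default the `generated_text` field of each dict contains the prompt followed by the continuation. A minimal sketch of separating the two is shown here; `split_continuation` is a hypothetical helper that is not part of `transformers`, and the sample list is a stand-in, not real model output.

```python
def split_continuation(outputs, prompt):
    """Return each generated text with the leading prompt removed."""
    continuations = []
    for item in outputs:
        text = item["generated_text"]
        # The pipeline echoes the prompt by default; drop it if present.
        if text.startswith(prompt):
            text = text[len(prompt):]
        continuations.append(text)
    return continuations


# Stand-in for the list returned by `generator(prompt, ...)`; not real model output.
example_prompt = "In this letter, we report "
example_outputs = [{"generated_text": example_prompt + "a placeholder continuation."}]

print(split_continuation(example_outputs, example_prompt))
```

With the objects defined above, the call would be `split_continuation(generated_text, prompt)`. Passing `return_full_text=False` to the generator call is another way to get only the continuation.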
Embedding text with AstroLLaMA
texts = [
    "Abstract 1",
    "Abstract 2"
]
inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096
)
# Move the tokenized batch to the same device as the model
inputs = inputs.to(model.device)
outputs = model(**inputs, output_hidden_states=True)

# Last layer of the hidden states, shape (batch, tokens, hidden_size).
# Average over the token dimension to get one embedding vector per text.
embeddings = outputs["hidden_states"][-1].mean(dim=1).detach().cpu().numpy()
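Because the batch is tokenized with `padding=True`, shorter texts carry padding tokens, and their hidden states are included when averaging over all tokens. One possible refinement is to weight the average by the attention mask so that only real tokens contribute. The sketch below assumes this is desirable for your use case; `masked_mean_pool` is a hypothetical helper that belongs to neither the model card nor `transformers`, and the tensors are dummy values standing in for `outputs["hidden_states"][-1]` and `inputs["attention_mask"]`.

```python
import torch


def masked_mean_pool(hidden_states, attention_mask):
    """Average token embeddings per text, ignoring padded positions."""
    # (batch, tokens) -> (batch, tokens, 1), matching the hidden-state dtype
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)   # (batch, hidden_size)
    counts = mask.sum(dim=1).clamp(min=1)        # (batch, 1), avoids division by zero
    return summed / counts


# Dummy data: 2 texts, 3 token positions, hidden size 2.
# The last position of the first text is padding and must be ignored.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                       [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]])
mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]])

pooled = masked_mean_pool(hidden, mask)
print(pooled)
```

With the objects defined above, the call would be `masked_mean_pool(outputs["hidden_states"][-1], inputs["attention_mask"])`, followed by `.detach().cpu().numpy()` as before.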