---
license: mit
datasets:
- universeTBD/arxiv-astro-abstracts-all
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
---

<p><h1>AstroLLaMA</h1></p>

<p align="center">
<img src="https://huggingface.co/universeTBD/astrollama/resolve/main/images/astrollama-logo.png" alt="AstroLLaMA" width="500px"/>
</p>

## Loading the model

```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
# device_map="auto" (requires the `accelerate` package) places the weights
# on the available GPU(s) first, then on CPU memory if they do not fit
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",
)
```
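How much memory the weights need decides whether `device_map="auto"` can keep everything on the GPU. As a rough sketch (the helper `estimate_weight_memory_gib` is hypothetical and counts the weights only, not activations, the generation cache, or framework overhead), the footprint is the parameter count times the bytes per parameter of the dtype; with the model loaded above, `model.num_parameters()` supplies the count.

```python
# Bytes per parameter for common torch dtypes
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def estimate_weight_memory_gib(num_params: int, dtype: str = "float32") -> float:
    """Rough memory footprint of the model weights alone, in GiB.

    Hypothetical helper: ignores activations, the KV cache and any overhead.
    """
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


# Usage with the model loaded above:
# print(estimate_weight_memory_gib(model.num_parameters(), "float32"))
```

Halving the bytes per parameter (float32 to float16 or bfloat16) halves this estimate, which is why reduced-precision loading is a common way to fit a model on a single GPU.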

## Generating text from a prompt

```python
import torch
from transformers import pipeline

# The model was already placed on its device(s) by device_map="auto" at load time
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = (
    "In this letter, we report the discovery of the highest redshift, "
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, "
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "
)

# For reproducibility
torch.manual_seed(42)

generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512,  # counts the prompt tokens as well as the generated ones
)
# The pipeline returns a list with one dict per generated sequence
print(generated_text[0]["generated_text"])
```
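The text-generation pipeline returns a list of dicts, one per generated sequence, and by default (`return_full_text=True`) each `"generated_text"` string begins with the prompt. A minimal sketch for separating the model's continuation from the prompt follows; the helper `split_continuation` is hypothetical, and passing `return_full_text=False` to the generator call is the pipeline's built-in way to get only the new text.

```python
from typing import Dict, List


def split_continuation(outputs: List[Dict[str, str]], prompt: str) -> List[str]:
    """Return only the newly generated text for each sequence.

    Hypothetical helper: assumes the pipeline's default return_full_text=True,
    so each "generated_text" starts with the prompt.
    """
    continuations = []
    for out in outputs:
        text = out["generated_text"]
        # Strip the prompt prefix if present; otherwise keep the text unchanged
        continuations.append(text[len(prompt):] if text.startswith(prompt) else text)
    return continuations


# Usage with the variables above:
# print(split_continuation(generated_text, prompt)[0])
```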
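
## Embedding text with AstroLLaMA

```python
import torch

texts = [
    "Abstract 1",
    "Abstract 2",
]

# The LLaMA-2 tokenizer has no padding token by default;
# reuse EOS for padding if none is set so that padding=True works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096,
).to(model.device)

# Inference only: skip gradient tracking to save memory
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Last layer of the hidden states: shape (batch, seq_len, hidden_size)
last_hidden = outputs.hidden_states[-1]

# Average over each text's real tokens only, so padding does not dilute the mean
mask = inputs["attention_mask"].unsqueeze(-1).to(device=last_hidden.device, dtype=last_hidden.dtype)
embeddings = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1)
embeddings = embeddings.float().cpu().numpy()  # shape: (len(texts), hidden_size)
```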
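A common use of text embeddings is comparing abstracts with one another. The sketch below computes pairwise cosine similarity for a 2-D array of shape (number of texts, hidden size), such as one mean-pooled vector per abstract. The helper `cosine_similarity_matrix` is hypothetical, and cosine similarity is one reasonable choice of metric, not one prescribed for AstroLLaMA.

```python
import numpy as np


def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a (n_texts, hidden_size) array."""
    # Normalize each row to unit length, guarding against all-zero rows
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    # Dot products of unit vectors are cosine similarities
    return unit @ unit.T


# Usage with per-text embeddings of shape (len(texts), hidden_size):
# sims = cosine_similarity_matrix(embeddings)
# sims[0, 1] compares "Abstract 1" with "Abstract 2"
```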