---
license: mit
datasets:
- universeTBD/arxiv-astro-abstracts-all
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
---
<p><h1>AstroLLaMA</h1></p>
<p align="center">
<img src="https://huggingface.co/universeTBD/astrollama/resolve/main/images/astrollama-logo.png" alt="AstroLLaMA" width="500px"/>
</p>
## Loading the model
```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",
)
```
## Generating text from a prompt
```python
import torch
from transformers import pipeline

generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = "In this letter, we report the discovery of the highest redshift, " \
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, " \
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "

# For reproducibility
torch.manual_seed(42)

generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512
)
```
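
The `text-generation` pipeline returns a list with one dictionary per generated sequence, and the text under the `generated_text` key starts with the prompt by default. A minimal sketch for separating the continuation from the prompt follows; `extract_continuation` is a hypothetical helper written for illustration, not part of `transformers`, and the sample output is a stand-in rather than real model output.

```python
def extract_continuation(pipeline_output, prompt):
    """Return the generated continuations with the leading prompt stripped.

    `pipeline_output` is assumed to have the shape returned by a
    text-generation pipeline: a list of {"generated_text": str} dicts.
    """
    continuations = []
    for item in pipeline_output:
        text = item["generated_text"]
        # By default the pipeline echoes the prompt at the start of the text.
        if text.startswith(prompt):
            text = text[len(prompt):]
        continuations.append(text.strip())
    return continuations


# Stand-in for the `generated_text` value produced above (illustrative only).
example_output = [{"generated_text": "In this letter, we report a result. More text."}]
print(extract_continuation(example_output, "In this letter, we report a result."))
```

With the objects from the snippet above, the call would be `extract_continuation(generated_text, prompt)`. Passing `return_full_text=False` to the generator call should give a similar result without post-processing.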
## Embedding text with AstroLLaMA
```python
texts = [
    "Abstract 1",
    "Abstract 2"
]
inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096
)
# Move the tokenized batch to the same device as the model
inputs = inputs.to(model.device)
outputs = model(**inputs, output_hidden_states=True)

# Last layer of the hidden states. Get the embedding of the first token in each sequence
embeddings = outputs["hidden_states"][-1][:, 0, ...].detach().cpu().numpy()
```
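
One common use of such embeddings is comparing abstracts by cosine similarity. The sketch below assumes an array of shape `(n_texts, hidden_dim)` like `embeddings` above; `cosine_similarity_matrix` is a hypothetical helper, and the toy vectors stand in for real model output.

```python
import numpy as np


def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarity for an array of shape (n_texts, hidden_dim)."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows.
    normalized = embeddings / np.clip(norms, 1e-12, None)
    return normalized @ normalized.T


# Toy vectors standing in for the `embeddings` array (illustrative only).
toy_embeddings = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
similarity = cosine_similarity_matrix(toy_embeddings)
print(similarity)
```

Entry `[i, j]` of the result is the similarity between text `i` and text `j`, so the diagonal is 1 for any non-zero embedding.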