---
license: mit
datasets:
  - universeTBD/arxiv-astro-abstracts-all
language:
  - en
metrics:
  - perplexity
pipeline_tag: text-generation
tags:
  - llama-2
  - astronomy
  - astrophysics
  - arxiv
---

AstroLLaMA

Loading the model

```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
# device_map="auto" (requires the accelerate package) places the weights
# on the available GPU(s) and CPU automatically
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",
)
```

Generating text from a prompt

```python
import torch
from transformers import pipeline

generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = (
    "In this letter, we report the discovery of the highest redshift, "
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, "
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "
)

# For reproducibility
torch.manual_seed(42)

# max_length counts the prompt tokens plus the newly generated tokens
generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512,
)
# The pipeline returns a list with one dict per generated sequence
print(generated_text[0]["generated_text"])
```

Embedding text with AstroLLaMA

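```python
import torch

texts = [
    "Abstract 1",
    "Abstract 2",
]

# LLaMA-2 tokenizers usually have no padding token; reuse EOS so that
# texts of different lengths can be batched together
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096,
).to(model.device)

with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs, output_hidden_states=True)

# Last layer of the hidden states: (batch, seq_len, hidden_size).
# Average over each text's real tokens, skipping padding positions.
last_hidden = outputs["hidden_states"][-1]
mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
embeddings = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1)
embeddings = embeddings.float().cpu().numpy()  # (len(texts), hidden_size)
```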
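Turning token vectors into one embedding per text means reducing over the sequence axis only, and padded positions should not count toward the average. A bare `.mean()` with no axis instead collapses the whole batch into a single scalar. The NumPy sketch below illustrates this on a made-up toy batch, so no model download is needed. The helper `masked_mean_pool` and the array values are hypothetical and for illustration only; they are not part of AstroLLaMA or `transformers`.

```python
import numpy as np


def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: average token vectors per text, ignoring padding.

    hidden_states: (batch, seq_len, hidden_size)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    returns: (batch, hidden_size)
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)
    summed = (hidden_states * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # guard against empty texts
    return summed / counts


# Toy batch: 2 texts, 3 positions, hidden size 2.
# The second text has one real token followed by two padding positions.
hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[10.0, 20.0], [99.0, 99.0], [99.0, 99.0]],
])
mask = np.array([[1, 1, 1], [1, 0, 0]])

print(hidden.mean())                   # one scalar for the whole batch
print(hidden.mean(axis=1))             # per text, but padding leaks into text 2
print(masked_mean_pool(hidden, mask))  # per text, padding ignored
```

Only the masked version gives the second text the vector of its single real token; the unmasked per-axis mean is pulled toward the padding values.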