# astrollama / app.py
# Author: Josh Nguyen
# Commit d38f5f1: "Fix a bug in generate_text"
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
MODEL_ID = "universeTBD/astrollama"
# Maximum number of tokens (prompt + generated) allowed per request; the
# generation budget is capped against this.
WINDOW_SIZE = 4096
# Device the tokenized prompt tensors are moved to before generation.
DEVICE = "cuda"

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)

# 4-bit quantization is requested through an explicit BitsAndBytesConfig:
# passing `load_in_4bit=True` directly to from_pretrained is deprecated.
# The compute dtype is set to match `torch_dtype`; left at its float32
# default, bitsandbytes warns that bfloat16 inputs will run slowly.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)
def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
    """Continue ``prompt`` with text sampled from the model.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this function drains the streamer and returns
    the prompt followed by everything that was generated.

    Args:
        prompt: Text to continue. It is tokenized without special tokens.
        max_new_tokens: Upper bound on generated tokens. It is further capped
            so that prompt + generation fits in ``WINDOW_SIZE`` tokens, and
            is never less than 1.
        temperature: Sampling temperature. Values >= 1.0 become 0.99 and
            values <= 0.0 become 0.01.
        top_p: Nucleus-sampling probability mass; values outside ``(0, 1]``
            fall back to 1.0.
        top_k: Top-k cutoff; non-positive values fall back to 100.

    Returns:
        ``prompt`` concatenated with the generated continuation.

    Raises:
        gr.Error: If the prompt alone already fills the context window.
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here.
    """
    # Encode the prompt without adding special tokens.
    inputs = tokenizer([prompt],
                       return_tensors='pt',
                       add_special_tokens=False).to(DEVICE)

    # Cap generation so prompt + new tokens fit in the window. If the prompt
    # alone fills it there is nothing to generate (recent transformers
    # versions raise on a non-positive max_new_tokens). That failure would
    # happen inside the worker thread and surface only as a streamer
    # timeout, so fail early with a readable message instead.
    input_length = inputs["input_ids"].shape[-1]
    remaining = WINDOW_SIZE - input_length
    if remaining <= 0:
        raise gr.Error(
            f"The prompt is {input_length} tokens long, which fills the "
            f"{WINDOW_SIZE}-token context window. Please shorten it."
        )
    max_new_tokens = max(1, min(int(max_new_tokens), remaining))

    # Keep sampling parameters in the ranges the UI exposes and generate()
    # accepts (top-k filtering requires a positive int).
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    top_k = int(top_k)
    if top_k <= 0:
        top_k = 100

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # An exception raised inside a Thread target is not propagated to the
    # caller; the loop below would just block until the streamer timeout.
    # Record the error and end the stream so it can be re-raised here.
    errors = []

    def _run_generation() -> None:
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the calling thread below
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    generated = "".join(streamer)
    thread.join()
    if errors:
        raise errors[0]
    return prompt + generated
# Web UI: a prompt box followed by one slider per generation parameter, in
# the same order as generate_text's positional parameters.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            lines=10,
            scale=10,
        ),
        gr.Slider(
            label="Maximum new tokens",
            minimum=1,
            # Never ask for more tokens than the whole window can hold.
            maximum=WINDOW_SIZE,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label='Top-k (for sampling)',
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        )
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
            lines=10,
        )
    ],
)

# Start the server only when run as a script, so that importing this module
# does not launch (and block on) the web server.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()