# HuggingFace Space: AstroLLaMA streaming text-generation demo.
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer
)

# Checkpoint to serve and the model's context-window size in tokens.
MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096
# Fall back to CPU when no GPU is present so the app can still start;
# a hard-coded "cuda" makes `inputs.to(DEVICE)` crash on CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)
# NOTE(review): load_in_4bit requires bitsandbytes and a CUDA device —
# on a CPU-only host this load will still fail; confirm the Space hardware.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16
)
def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
    """Generate a continuation of ``prompt`` with the loaded model.

    Sampling parameters are clamped into safe ranges before generation.
    Generation runs on a worker thread while this thread drains the
    token stream.

    Returns:
        The original prompt with the generated continuation appended.
    """
    # Encode the prompt; no special tokens, we continue the raw text.
    inputs = tokenizer([prompt],
                       return_tensors='pt',
                       add_special_tokens=False).to(DEVICE)

    # Never ask for more tokens than fit in the context window, but always
    # request at least one: for prompts that fill the window the original
    # formula went non-positive, which makes `generate` raise.
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))

    # Clamp sampling parameters into the ranges the sliders advertise.
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100

    # skip_prompt: stream only newly generated tokens, not the input.
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Generate on a worker thread; consume the stream here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
    # Make sure the worker has fully finished before returning.
    thread.join()
    return generated_text
# --- Gradio UI -------------------------------------------------------------
# Input widgets: the free-form prompt plus the sampling controls, in the
# order matching generate_text's parameters.
_prompt_input = gr.Textbox(
    label="Prompt",
    container=False,
    show_label=False,
    placeholder="Enter some text...",
    lines=10,
    scale=10,
)
_max_tokens_slider = gr.Slider(
    label="Maximum new tokens",
    minimum=1,
    maximum=4096,
    step=1,
    value=1024,
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0.01,
    maximum=0.99,
    step=0.01,
    value=0.5,
)
_top_p_slider = gr.Slider(
    label="Top-p (for sampling)",
    minimum=0.05,
    maximum=1.0,
    step=0.05,
    value=0.95,
)
_top_k_slider = gr.Slider(
    label="Top-k (for sampling)",
    minimum=1,
    maximum=1000,
    step=1,
    value=100,
)
# Output widget: the generated continuation.
_output_box = gr.Textbox(
    container=False,
    show_label=False,
    placeholder="Generated output...",
    scale=10,
    lines=10,
)

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        _prompt_input,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
        _top_k_slider,
    ],
    outputs=[_output_box],
)
demo.queue(max_size=20).launch()