from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer,
)

MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096  # model context length: prompt + generated tokens must fit
DEVICE = "cuda"

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)
# 4-bit quantized load; device_map="auto" lets accelerate place the weights.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
)


def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
    """Sample a continuation of *prompt* with the loaded model.

    Args:
        prompt: Input text; tokenized without special tokens.
        max_new_tokens: Upper bound on generated tokens; clamped so the
            prompt plus generation fits inside ``WINDOW_SIZE``.
        temperature: Sampling temperature, clamped into [0.01, 0.99].
        top_p: Nucleus-sampling threshold; out-of-range values fall back to 1.0.
        top_k: Top-k sampling cutoff; non-positive values fall back to 100.

    Returns:
        The prompt concatenated with the generated continuation.

    NOTE(review): the function accumulates the stream and returns once at the
    end; to stream partial output into the Gradio textbox, turn the loop into
    ``yield generated_text`` instead of returning.
    """
    # Encode the prompt (no BOS/EOS — the model continues the raw text).
    inputs = tokenizer([prompt], return_tensors='pt',
                       add_special_tokens=False).to(DEVICE)

    # Keep the request inside the context window. Clamp to >= 1 so that a
    # very long prompt cannot drive max_new_tokens to 0 or negative, which
    # would make model.generate() raise.
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))

    # Sanitize sampling parameters into valid ranges.
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100

    # Stream tokens out of the generation thread; skip_prompt avoids echoing
    # the input, since we prepend the prompt ourselves below.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Run generate() in a worker thread; the streamer iterator below blocks
    # until tokens arrive and stops when generation finishes.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
    # The streamer is exhausted only when generate() has ended, so this join
    # returns immediately; it just guarantees the thread is fully cleaned up.
    thread.join()
    return generated_text


demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            lines=10,
            scale=10,
        ),
        gr.Slider(
            label="Maximum new tokens",
            minimum=1,
            maximum=4096,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label='Top-k (for sampling)',
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        ),
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
            lines=10,
        )
    ],
)

demo.queue(max_size=20).launch()