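"""Gradio demo that streams chat responses from the
genaforvena/huivam_finnegan_llama3.2-1b model using a custom TextStreamer."""
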
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, GenerationConfig
import torch
import threading
from queue import Queue, Empty


# Custom streamer that buffers generated tokens and pushes the decoded text
# into a queue so the Gradio generator can read it from another thread.
class MyStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        # Pop skip_special_tokens (default True) before handing the remaining
        # kwargs to TextStreamer, so it is not passed twice when decoding below.
        self.skip_special_tokens = decode_kwargs.pop("skip_special_tokens", True)
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.token_cache = []  # Tokens generated so far

    def on_finalized_text(self, text, stream_end=False):
        """Put the decoded text in the queue."""
        self.text_queue.put(text)

    def put(self, value):
        """Buffer the new tokens and emit the decoded text so far."""
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("put() only supports a single sequence of tokens at a time.")
        elif len(value.shape) > 1:
            value = value[0]

        # Skip the prompt tokens passed in on the first call
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new tokens to the cache
        self.token_cache.extend(value.tolist())

        # Decode the entire cache so far
        text = self.tokenizer.decode(
            self.token_cache,
            skip_special_tokens=self.skip_special_tokens,
            **self.decode_kwargs,
        )

        # Check for the stop signal (e.g. an end-of-text marker), if one is set
        if self.stop_signal and text.endswith(self.stop_signal):
            text = text[: -len(self.stop_signal)]
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text(text, stream_end=False)

    def end(self):
        """Flush any remaining buffered text when generation ends."""
        if self.token_cache:
            text = self.tokenizer.decode(
                self.token_cache,
                skip_special_tokens=self.skip_special_tokens,
                **self.decode_kwargs,
            )
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text("", stream_end=True)


# Load the model and tokenizer
model_name = "genaforvena/huivam_finnegan_llama3.2-1b"
model = None
tokenizer = None
try:
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    exit()

# Move the model to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
if model:
    model.to(device)
    print(f"Model moved to {device}.")


# Generate a streaming response for a single prompt
def reply(prompt):
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)

        # Create a custom streamer
        streamer = MyStreamer(tokenizer, skip_prompt=True)

        generation_config = GenerationConfig(
            # Fall back to the EOS token when the tokenizer defines no pad token
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )

        # Run generation in a background thread so we can stream from the queue
        def generate():
            model.generate(
                inputs,
                generation_config=generation_config,
                streamer=streamer,
                max_new_tokens=512,  # Adjust as needed
            )

        thread = threading.Thread(target=generate)
        thread.start()

        # Yield the decoded text so far as it arrives; Gradio replaces the
        # output with each yielded value, so the response appears to grow.
        while thread.is_alive():
            try:
                yield streamer.text_queue.get(timeout=0.1)
            except Empty:
                pass

        # Yield any remaining text after generation finishes
        while not streamer.text_queue.empty():
            yield streamer.text_queue.get()
    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Error processing your request: {e}"


# Gradio interface
demo = gr.Interface(
    fn=reply,
    inputs="text",
    outputs="text",
)

# Launch the Gradio app
demo.launch(share=True)
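# Note: share=True requests a temporary public gradio.live link in addition to
# the local server; set share=False to serve locally only.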