import threading
from queue import Empty, Queue

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer


# Custom streamer that pushes decoded text onto a queue instead of printing it
class MyStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.skip_special_tokens = decode_kwargs.pop("skip_special_tokens", True)  # Default to True
        self.token_cache = []  # Cache of generated token ids

    def on_finalized_text(self, text, stream_end=False):
        """Put the new text in the queue."""
        self.text_queue.put(text)

    def put(self, value):
        """Decode the cached tokens and push the text generated so far onto the queue."""
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("put() only supports a single sequence of tokens at a time.")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new tokens to the cache
        self.token_cache.extend(value.tolist())

        # Decode the entire cache
        text = self.tokenizer.decode(
            self.token_cache,
            skip_special_tokens=self.skip_special_tokens,
            **self.decode_kwargs,
        )

        # Check for a stop signal (e.g., end of text)
        if self.stop_signal and text.endswith(self.stop_signal):
            text = text[: -len(self.stop_signal)]
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text(text, stream_end=False)

    def end(self):
        """Flush the buffer."""
        if self.token_cache:
            text = self.tokenizer.decode(
                self.token_cache,
                skip_special_tokens=self.skip_special_tokens,
                **self.decode_kwargs,
            )
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text("", stream_end=True)


# Load the model and tokenizer
model_name = "genaforvena/huivam_finnegan_llama3.2-1b"
model = None
tokenizer = None

try:
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    exit()

# Move the model to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
if model:
    model.to(device)
    print(f"Model moved to {device}.")


# Function to generate a streaming response
def reply(prompt):
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)

        # Create a custom streamer
        streamer = MyStreamer(tokenizer, skip_prompt=True)

        generation_config = GenerationConfig(
            pad_token_id=tokenizer.pad_token_id,
        )

        # Run generation in a background thread so we can stream from the queue
        def generate():
            model.generate(
                inputs,
                generation_config=generation_config,
                streamer=streamer,
                max_new_tokens=512,  # Adjust as needed
            )

        thread = threading.Thread(target=generate)
        thread.start()

        # Yield the decoded text as it arrives; each yield replaces the Gradio output
        while thread.is_alive():
            try:
                text_so_far = streamer.text_queue.get(timeout=0.1)
                yield text_so_far
            except Empty:
                pass

        # Yield any remaining text after generation finishes
        while not streamer.text_queue.empty():
            text_so_far = streamer.text_queue.get()
            yield text_so_far

    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Error processing your request: {e}"


# Gradio interface
demo = gr.Interface(
    fn=reply,
    inputs="text",
    outputs="text",
)

# Launch the Gradio app
demo.launch(share=True)