import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, GenerationConfig
import torch
import threading
from queue import Queue, Empty


# Custom streamer that pushes decoded text onto a queue so it can be
# consumed from another thread.
class MyStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        # Pop skip_special_tokens before forwarding decode_kwargs to the base
        # class; otherwise the decode() calls below would receive it twice.
        self.skip_special_tokens = decode_kwargs.pop("skip_special_tokens", True)  # Default to True
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.token_cache = []  # Cache of generated token ids
    def on_finalized_text(self, text, stream_end=False):
        """Put the new text in the queue."""
        self.text_queue.put(text)

    def put(self, value):
        """Decode the token and add to buffer."""
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("put() only supports a single sequence of tokens at a time.")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the tokens to the cache
        self.token_cache.extend(value.tolist())

        # Decode the entire cache
        text = self.tokenizer.decode(
            self.token_cache,
            skip_special_tokens=self.skip_special_tokens,
            **self.decode_kwargs,
        )

        # Check for a stop signal (e.g., end of text)
        if self.stop_signal and text.endswith(self.stop_signal):
            text = text[: -len(self.stop_signal)]
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text(text, stream_end=False)

    def end(self):
        """Flush the buffer."""
        if self.token_cache:
            text = self.tokenizer.decode(
                self.token_cache,
                skip_special_tokens=self.skip_special_tokens,
                **self.decode_kwargs,
            )
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text("", stream_end=True)
# Load the model and tokenizer
model_name = "genaforvena/huivam_finnegan_llama3.2-1b"
model = None
tokenizer = None
try:
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    exit()

# Move the model to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
if model:
    model.to(device)
    print(f"Model moved to {device}.")

# Function to generate a streaming response
def reply(prompt):
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)

        # Create a custom streamer
        streamer = MyStreamer(tokenizer, skip_prompt=True)
        generation_config = GenerationConfig(
            pad_token_id=tokenizer.pad_token_id,
        )

        # Run generation in a background thread so text can be streamed while it is produced
        def generate():
            model.generate(
                inputs,
                generation_config=generation_config,
                streamer=streamer,
                max_new_tokens=512,  # Adjust as needed
            )

        thread = threading.Thread(target=generate)
        thread.start()

        # Yield streamed text as it becomes available
        while thread.is_alive():
            try:
                next_text = streamer.text_queue.get(timeout=0.1)
                yield next_text
            except Empty:
                pass

        # Yield any remaining text after generation finishes
        while not streamer.text_queue.empty():
            next_text = streamer.text_queue.get()
            yield next_text
    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Error processing your request: {e}"

# Gradio interface
demo = gr.Interface(
    fn=reply,
    inputs="text",
    outputs="text",
)
# Launch the Gradio app
demo.launch(share=True)
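
# Optional sketch (assumes the model above loaded successfully and that this module is
# imported rather than run as the blocking Gradio app): reply() is a plain generator,
# so it can be exercised without the UI by keeping the last yielded item, e.g.
#
#     final = ""
#     for partial in reply("Hello"):
#         final = partial  # each item is the cumulative decoded text (see note above)
#     print(final)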