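# Gradio demo that streams text from a causal language model: model.generate()
# runs on a background thread while a custom TextStreamer pushes decoded text
# onto a queue that the Gradio generator drains and yields to the UI.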
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, GenerationConfig
import torch
import threading
from queue import Queue, Empty

# Custom streamer that pushes decoded text onto a queue instead of printing it,
# so the Gradio generator below can consume it as it is produced
class MyStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        # Pop before calling super() so skip_special_tokens cannot reach decode() twice
        self.skip_special_tokens = decode_kwargs.pop("skip_special_tokens", True)  # Default to True
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None  # Optional end-of-text marker checked in put()
        self.token_cache = []  # Accumulates generated token ids

    def on_finalized_text(self, text, stream_end=False):
        """Put the new text in the queue."""
        self.text_queue.put(text)

    def put(self, value):
        """Accumulate new token ids, decode the whole cache, and queue the full text so far."""
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("put() only supports a single sequence of tokens at a time.")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the token to the cache
        self.token_cache.extend(value.tolist())

        # Decode the entire cache
        text = self.tokenizer.decode(
            self.token_cache,
            skip_special_tokens=self.skip_special_tokens,
            **self.decode_kwargs,
        )

        # Check for stop signal (e.g., end of text)
        if self.stop_signal and text.endswith(self.stop_signal):
            text = text[: -len(self.stop_signal)]
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text(text, stream_end=False)

    def end(self):
        """Flush the buffer."""
        if self.token_cache:
            text = self.tokenizer.decode(
                self.token_cache,
                skip_special_tokens=self.skip_special_tokens,
                **self.decode_kwargs,
            )
            self.on_finalized_text(text, stream_end=True)
            self.token_cache = []  # Clear the cache
        else:
            self.on_finalized_text("", stream_end=True)

# Load the model and tokenizer
model_name = "genaforvena/huivam_finnegan_llama3.2-1b"
model = None
tokenizer = None
try:
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    exit()

# Move the model to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
if model:
    model.to(device)
    print(f"Model moved to {device}.")

# Function to generate a streaming response
def reply(prompt):
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)

        # Create a custom streamer
        streamer = MyStreamer(tokenizer, skip_prompt=True)

        generation_config = GenerationConfig(
            # Llama tokenizers often ship without a pad token; fall back to EOS
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

        def generate():
            model.generate(
                inputs,
                generation_config=generation_config,
                streamer=streamer,
                max_new_tokens=512,  # Adjust as needed
            )

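        # Run generation on a background thread so this generator can stream
        # partial output while the model is still producing tokens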
        thread = threading.Thread(target=generate)
        thread.start()

        # Stream text as it arrives; each queue item is the full text decoded so
        # far, and Gradio replaces the displayed output on every yield
        while thread.is_alive():
            try:
                new_text = streamer.text_queue.get(timeout=0.1)
                yield new_text
            except Empty:
                continue

        # Drain anything still queued after generation finishes
        while not streamer.text_queue.empty():
            yield streamer.text_queue.get()

    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Error processing your request: {e}"

# Gradio interface
demo = gr.Interface(
    fn=reply,
    inputs="text",
    outputs="text",
)
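# Because reply() is a generator, Gradio refreshes the text output on every
# yield, which produces the streaming effect in the UI.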

# Launch the Gradio app
demo.launch(share=True)