genaforvena committed
Commit 1307336 · 1 Parent(s): 6384c62
Files changed (1)
  1. app.py +125 -20
app.py CHANGED
@@ -1,29 +1,134 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftModel
  import torch

- base_model_name = "unsloth/Llama-3.2-1B-Instruct"
- base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

- tokenizer = AutoTokenizer.from_pretrained(base_model_name)

- peft_model_path = "genaforvena/huivam_finnegan_llama3.2-1b"
- model = PeftModel.from_pretrained(base_model, peft_model_path)

  device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)

  def reply(prompt):
-     print("prompt: " + prompt)
-     inputs = tokenizer.encode(prompt, return_tensors="pt")
-     print("tokenized")
-     output = model.generate(inputs, max_new_tokens=100, do_sample=True, top_p=0.95, top_k=50)
-     print("generated")
-     text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     print("text: " + text)
-     return text
-
- demo = gr.Interface(fn=reply, inputs="text", outputs="text")
- demo.launch()
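Note: the removed code above loaded the unsloth/Llama-3.2-1B-Instruct base model and attached genaforvena/huivam_finnegan_llama3.2-1b as a PEFT adapter, while the new code below loads that repo directly with AutoModelForCausalLM. Purely as a sketch of how such a setup can be collapsed into one loadable checkpoint (an assumption about the workflow, not part of this commit, and assuming a LoRA-style adapter), peft's merge_and_unload() bakes the adapter into the base weights:

# Hypothetical, not part of this commit: merge the adapter into the base model so the
# result can be loaded with AutoModelForCausalLM alone. (When peft is installed,
# transformers can also load an adapter repo directly, so merging is only one option.)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
merged = PeftModel.from_pretrained(base, "genaforvena/huivam_finnegan_llama3.2-1b").merge_and_unload()
merged.save_pretrained("merged-checkpoint")  # hypothetical local output directory
AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct").save_pretrained("merged-checkpoint")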
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, GenerationConfig
  import torch
+ import threading
+ from queue import Queue

+ # Custom Streamer Class
+ class MyStreamer(TextStreamer):
+     def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
+         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+         self.text_queue = Queue()
+         self.stop_signal = None
+         self.skip_special_tokens = decode_kwargs.pop("skip_special_tokens", True)  # Default to True
+         self.token_cache = []  # Add a token cache

+     def on_finalized_text(self, text, stream_end=False):
+         """Put the new text in the queue."""
+         self.text_queue.put(text)

+     def put(self, value):
+         """Decode the token and add to buffer."""
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("put() only supports a single sequence of tokens at a time.")
+         elif len(value.shape) > 1:
+             value = value[0]

+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         # Add the token to the cache
+         self.token_cache.extend(value.tolist())
+
+         # Decode the entire cache
+         text = self.tokenizer.decode(
+             self.token_cache,
+             skip_special_tokens=self.skip_special_tokens,
+             **self.decode_kwargs,
+         )
+
+         # Check for stop signal (e.g., end of text)
+         if self.stop_signal and text.endswith(self.stop_signal):
+             text = text[: -len(self.stop_signal)]
+             self.on_finalized_text(text, stream_end=True)
+             self.token_cache = []  # Clear the cache
+         else:
+             self.on_finalized_text(text, stream_end=False)
+
+     def end(self):
+         """Flush the buffer."""
+         if self.token_cache:
+             text = self.tokenizer.decode(
+                 self.token_cache,
+                 skip_special_tokens=self.skip_special_tokens,
+                 **self.decode_kwargs,
+             )
+             self.on_finalized_text(text, stream_end=True)
+             self.token_cache = []  # Clear the cache
+         else:
+             self.on_finalized_text("", stream_end=True)
+
+ # Load the model and tokenizer
+ model_name = "genaforvena/huivam_finnegan_llama3.2-1b"
+ model = None
+ tokenizer = None
+ try:
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     print("Model and tokenizer loaded successfully.")
+ except Exception as e:
+     print(f"Error loading model/tokenizer: {e}")
+     exit()
+
+ # Move the model to the appropriate device
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ if model:
+     model.to(device)
+     print(f"Model moved to {device}.")

+ # Function to generate a streaming response
  def reply(prompt):
+     messages = [{"role": "user", "content": prompt}]
+     try:
+         inputs = tokenizer.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors="pt",
+         ).to(device)
+
+         # Create a custom streamer
+         streamer = MyStreamer(tokenizer, skip_prompt=True)
+
+         generation_config = GenerationConfig(
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+         def generate():
+             model.generate(
+                 inputs,
+                 generation_config=generation_config,
+                 streamer=streamer,
+                 max_new_tokens=512,  # Adjust as needed
+             )
+
+         thread = threading.Thread(target=generate)
+         thread.start()
+
+         # Yield only the new tokens as they come in
+         while thread.is_alive():
+             try:
+                 next_token = streamer.text_queue.get(timeout=0.1)
+                 yield next_token  # Yield only the new token
+             except:
+                 pass
+
+         # Yield any remaining text after generation finishes
+         while not streamer.text_queue.empty():
+             next_token = streamer.text_queue.get()
+             yield next_token  # Yield only the new token
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         yield f"Error processing your request: {e}"
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=reply,
+     inputs="text",
+     outputs="text",
+ )
+
+ # Launch the Gradio app
+ demo.launch(share=True)
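
For reference, reply() streams by having a background thread push decoded text onto the queue while the generator drains it; because MyStreamer decodes its whole token cache on every put(), each yielded chunk is the full text generated so far rather than a single new token, which lets the Gradio textbox simply replace its contents on each update. A minimal sketch of exercising it outside the UI (assumptions: the code above is importable as a module named app, and demo.launch(share=True) is guarded so importing does not start the server):

# Hypothetical smoke test, not part of this commit.
from app import reply  # assumes app.py's launch call is guarded for import

latest = ""
for chunk in reply("Who is Anna Livia Plurabelle?"):  # any prompt string works
    latest = chunk  # each chunk is the cumulative decoded reply so far
print(latest)  # the final, complete reply

# Against the running app, the same endpoint could also be queried over HTTP, e.g. with
# gradio_client using gr.Interface's default /predict route (also an assumption):
# from gradio_client import Client
# print(Client("http://127.0.0.1:7860/").predict("Hello", api_name="/predict"))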