ZennyKenny committed
Commit 4e01411 · verified · 1 Parent(s): c4c5c31

Update app.py

Files changed (1)
  1. app.py +15 -10
app.py CHANGED
@@ -3,7 +3,8 @@
 import spaces
 import gradio as gr
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import threading
 
 # Load the base model
 base_model = AutoModelForCausalLM.from_pretrained(
@@ -29,19 +30,23 @@ def generate_response(prompt):
     )
     inputs = tokenizer(reasoning_prompt, return_tensors="pt").to(model.device)
 
-    # Streamed response
-    stream = model.generate(
+    # Using TextIteratorStreamer for streaming responses
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+    generation_kwargs = dict(
         **inputs,
-        max_new_tokens=300,  # Increased token limit
+        max_new_tokens=300,
         do_sample=True,
         temperature=0.8,
         top_p=0.95,
-        stream=True
+        streamer=streamer
     )
-
-    # Yield output tokens in real-time
-    for chunk in stream:
-        yield tokenizer.decode(chunk[0], skip_special_tokens=True)
+
+    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    for new_text in streamer:
+        yield new_text
 
 demo = gr.Interface(
     fn=generate_response,
@@ -52,4 +57,4 @@ demo = gr.Interface(
     live=True
 )
 
-demo.launch()
+demo.launch(share=True)
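For reference, the pattern this commit adopts is the usual Transformers streaming recipe: model.generate blocks until generation completes, so it runs in a background thread while TextIteratorStreamer yields decoded text to the caller as tokens become available. Below is a minimal, self-contained sketch of that pattern, not the Space's exact code: it assumes a placeholder "gpt2" checkpoint instead of the actual base model and PEFT adapter, sets skip_prompt=True to drop the echoed prompt, and accumulates the streamed text because each value yielded to Gradio replaces the displayed output.

import threading

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder checkpoint for illustration; the Space loads its own base model and PEFT adapter.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        streamer=streamer,
    )

    # generate() blocks, so run it in a worker thread and consume the streamer here
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    text = ""
    for new_text in streamer:
        text += new_text
        yield text  # each yield replaces the Gradio output with the text so far

demo = gr.Interface(fn=generate_response, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()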