[email protected] committed
Commit 86216f9 · 1 Parent(s): c5af4d7

Stream output with TextIteratorStreamer

Files changed (1):
  1. app.py +24 -11
app.py CHANGED
@@ -9,8 +9,9 @@ model_id: "eltorio/Llama-3.2-3B-appreciation"
 Author: Ronan Le Meillat
 License: AGPL-3.0
 """
+from threading import Thread
 import gradio as gr
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from peft import AutoPeftModelForCausalLM
 import torch
 import os
@@ -77,22 +78,34 @@ def infere(trimestre: str, moyenne_1: float,moyenne_2: float,moyenne_3: float, c
     """,
     duration=500)
     messages = get_conversation(trimestre, moyenne_1, moyenne_2, moyenne_3, comportement, participation, travail)
+
     # Tokenize the input
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize = True,
         add_generation_prompt = True,
         return_tensors = "pt",).to(device)
-    # Generate the output
-    outputs = model.generate(input_ids = inputs,
-                             max_new_tokens = 90,
-                             use_cache = True,
-                             temperature = 1.5,
-                             min_p = 0.1,
-                             pad_token_id=tokenizer.eos_token_id,)
-    # Decodes the returned tokens
-    decoded_sequences = tokenizer.batch_decode(outputs[:, inputs.shape[1]:],skip_special_tokens=True)[0]
-    return decoded_sequences
+
+    # Run generation in a separate thread so the UI is not blocked; the text is pulled from the
+    # streamer in the main thread. The timeout lets the streamer surface exceptions from the generation thread.
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids = inputs,
+        streamer=streamer,
+        max_new_tokens=90,
+        use_cache = True,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    generation_thread.start()
+
+    # Pull the generated text from the streamer and update the model output incrementally.
+    model_output = ""
+    for new_text in streamer:
+        model_output += new_text
+        yield model_output
+    return model_output
+
 
 # Create a Gradio interface with the infere function and specified title and descriptions
 autoeval = gr.Interface(fn=infere, inputs=[
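
For readers who want the pattern outside the diff: below is a minimal, self-contained sketch of the TextIteratorStreamer approach this commit adopts. The "gpt2" checkpoint and the stream_completion helper are illustrative stand-ins, not part of this Space; the streamer arguments (timeout, skip_prompt, skip_special_tokens) mirror the committed code.

# Stand-alone sketch of the TextIteratorStreamer pattern; not the Space's app.py.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

def stream_completion(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # skip_prompt=True yields only newly generated text; the timeout makes the
    # iterator raise instead of hanging forever if the generation thread dies.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # model.generate blocks until generation finishes, so it runs in a worker
    # thread while this thread drains the streamer's queue.
    thread = Thread(target=model.generate, kwargs=dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=40,
        pad_token_id=tokenizer.eos_token_id,
    ))
    thread.start()
    text = ""
    for chunk in streamer:  # blocks until the next decoded chunk arrives
        text += chunk
        yield text
    thread.join()

for partial in stream_completion("Streaming keeps the UI responsive because"):
    print(partial)

The design point is that generate() is synchronous and the streamer is essentially a queue that generate() feeds as tokens are produced, so the producer needs its own thread while the caller iterates as the consumer.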
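
Relatedly, the reason infere can simply yield partial strings is that recent Gradio versions stream whenever the function passed to gr.Interface is a generator, re-rendering the output on every yield. A minimal demo of that behavior (slow_echo is a hypothetical example, not from this Space):

import time
import gradio as gr

def slow_echo(message: str):
    out = ""
    for ch in message:
        out += ch
        time.sleep(0.05)
        yield out  # each yield replaces the displayed output

gr.Interface(fn=slow_echo, inputs="text", outputs="text").launch()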