Somekindofa committed on
Commit
0aba972
·
1 Parent(s): 03a78ae

Refactor generate function to handle input token length and implement threading for model generation

Browse files
Files changed (1) hide show
  1. app.py +30 -5
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
  import os
7
  import torch
8
  from typing import Optional, Iterator
 
9
 
10
 
11
  # Initialize logging and device information
@@ -79,11 +80,35 @@ def generate(
79
  conversation.append({"role": "user", "content": message})
80
 
81
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
82
- print("\nInput Ids:\t", input_ids, "Type:\t", type(input_ids))
83
- for i, token in enumerate(input_ids):
84
- print(f"ID {i}:", token[0])
85
- print(f"ID {token[0]} -> '{tokenizer.decode(token[0])}'")
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  chat_interface = gr.ChatInterface(
89
  fn=generate,
 
6
  import os
7
  import torch
8
  from typing import Optional, Iterator
9
+ from threading import Thread
10
 
11
 
12
  # Initialize logging and device information
 
80
  conversation.append({"role": "user", "content": message})
81
 
82
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
83
+ print("Input Ids Shape: ", input_ids.shape)
84
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
85
+ gr.Warning(f"Trimmed the input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
86
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH]
87
+ input_ids = input_ids.to(model.device)
88
+
89
+ streamer = TextIteratorStreamer(tokenizer,
90
+ timeout=10.0,
91
+ skip_prompt=True,
92
+ skip_special_tokens=True)
93
+ generate_kwargs = dict(
94
+ {"input_ids": input_ids},
95
+ streamer=streamer,
96
+ max_new_tokens=max_new_tokens,
97
+ do_sample=True,
98
+ top_p=top_p,
99
+ top_k=top_k,
100
+ temperature=temperature,
101
+ num_beams=1,
102
+ repetition_penalty=repetition_penalty
103
+ )
104
+
105
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
106
+ t.start()
107
+
108
+ outputs = []
109
+ for text in streamer:
110
+ outputs.append(text)
111
+ yield "".join(outputs)
112
 
113
  chat_interface = gr.ChatInterface(
114
  fn=generate,