leonardlin committed
Commit 9dfa458
Parent: f929d3a

remove streamer - threads causing weird issues

Files changed (1)
  1. app.py +5 -9
app.py CHANGED
@@ -42,7 +42,6 @@ model = AutoModelForCausalLM.from_pretrained(
         bnb_4bit_compute_dtype=torch.bfloat16
     ),
 )
-streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
 def chat(message, history, system_prompt):
     print('---')
@@ -65,7 +64,6 @@ def chat(message, history, system_prompt):
 
     generate_kwargs = dict(
         inputs=input_ids,
-        streamer=streamer,
         max_new_tokens=200,
         do_sample=True,
         temperature=0.7,
@@ -74,13 +72,11 @@ def chat(message, history, system_prompt):
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.eos_token_id,
     )
-    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-    partial_message = ""
-    for new_token in streamer:
-        partial_message += new_token  # html.escape(new_token)
-        yield partial_message
+
+    output_ids = model.generate(**generate_kwargs)
+    new_tokens = output_ids[0, input_ids.size(1):]
+    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+    return response
 
 
 chat_interface = gr.ChatInterface(
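
For context, here is a minimal self-contained sketch of what the handler looks like after this commit, with synchronous generation in place of the threaded TextIteratorStreamer. Only the generation arguments and the decode-and-return tail come from the diff; the model id, the prompt assembly via apply_chat_template, and the ChatInterface wiring are illustrative assumptions, not the app's actual code.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "example/model"  # assumption: placeholder, not the app's real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

def chat(message, history, system_prompt):
    # Assumption: build the prompt with the tokenizer's chat template;
    # app.py may assemble input_ids differently.
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in history:
        messages += [{"role": "user", "content": user_msg},
                     {"role": "assistant", "content": bot_msg}]
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Synchronous generation: no TextIteratorStreamer, no background thread.
    output_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens and decode only the newly generated ones.
    new_tokens = output_ids[0, input_ids.size(1):]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Because chat() now returns one complete string instead of yielding
# partial ones, gr.ChatInterface runs it as a non-streaming handler.
chat_interface = gr.ChatInterface(
    fn=chat,
    additional_inputs=[gr.Textbox(label="System prompt")],
)

Returning a single string trades token-by-token streaming for simplicity; the removed Thread/TextIteratorStreamer pair is what the commit message blames for the "weird issues".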