Sakalti committed on
Commit
b26e890
·
verified ·
1 Parent(s): 451d731

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
 
4
 
5
  import gradio as gr
6
  import spaces
@@ -29,7 +30,6 @@ model = AutoModelForCausalLM.from_pretrained(
29
  model.config.sliding_window = 4096
30
  model.eval()
31
 
32
-
33
  @spaces.GPU
34
  def generate(
35
  message: str,
@@ -46,11 +46,11 @@ def generate(
46
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
47
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
48
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
49
- input_ids = input_ids.to(model.device)
50
 
51
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
52
  generate_kwargs = dict(
53
- {"input_ids": input_ids},
54
  streamer=streamer,
55
  max_new_tokens=max_new_tokens,
56
  do_sample=True,
@@ -64,11 +64,15 @@ def generate(
64
  t.start()
65
 
66
  outputs = []
67
- for text in streamer:
68
- outputs.append(text)
 
 
 
 
 
69
  yield "".join(outputs)
70
 
71
-
72
  demo = gr.ChatInterface(
73
  fn=generate,
74
  type="messages",
@@ -122,6 +126,5 @@ demo = gr.ChatInterface(
122
  cache_examples=False,
123
  )
124
 
125
-
126
  if __name__ == "__main__":
127
- demo.launch()
 
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
4
+ import queue
5
 
6
  import gradio as gr
7
  import spaces
 
30
  model.config.sliding_window = 4096
31
  model.eval()
32
 
 
33
  @spaces.GPU
34
  def generate(
35
  message: str,
 
46
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
47
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
48
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
49
+ input_ids = input_ids.to(device)
50
 
51
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
52
  generate_kwargs = dict(
53
+ input_ids=input_ids,
54
  streamer=streamer,
55
  max_new_tokens=max_new_tokens,
56
  do_sample=True,
 
64
  t.start()
65
 
66
  outputs = []
67
+ try:
68
+ for text in streamer:
69
+ outputs.append(text)
70
+ yield "".join(outputs)
71
+ except queue.Empty:
72
+ # キューが空になった場合の処理
73
+ gr.Warning("生成プロセスがタイムアウトしました。")
74
  yield "".join(outputs)
75
 
 
76
  demo = gr.ChatInterface(
77
  fn=generate,
78
  type="messages",
 
126
  cache_examples=False,
127
  )
128
 
 
129
  if __name__ == "__main__":
130
+ demo.launch()