Sakalti commited on
Commit
cb329ff
·
verified ·
1 Parent(s): fc27e2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -2,17 +2,16 @@
2
 
3
  import os
4
  from threading import Thread
 
5
  from typing import Iterator
6
 
7
  import gradio as gr
8
  import spaces
9
  import torch
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
12
  DESCRIPTION = "# Sakaltum-7B-chat"
13
-
14
- if not torch.cuda.is_available():
15
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo might be slower on CPU.</p>"
16
 
17
  MAX_MAX_NEW_TOKENS = 2048
18
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -22,7 +21,7 @@ model_id = "sakaltcommunity/sakaltum-7b"
22
  if torch.cuda.is_available():
23
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
24
  else:
25
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
26
  model.eval()
27
 
28
  tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -56,26 +55,34 @@ def generate(
56
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
57
  input_ids = input_ids.to(model.device)
58
 
59
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
60
- generate_kwargs = dict(
61
- input_ids=input_ids,
62
- streamer=streamer,
63
- max_new_tokens=max_new_tokens,
64
- do_sample=True,
65
- top_p=top_p,
66
- top_k=top_k,
67
- temperature=temperature,
68
- num_beams=1,
69
- repetition_penalty=repetition_penalty,
70
- pad_token_id=tokenizer.eos_token_id,
71
- )
72
- t = Thread(target=model.generate, kwargs=generate_kwargs)
73
- t.start()
 
 
74
 
75
  outputs = []
76
- for text in streamer:
77
- outputs.append(text)
78
- yield "".join(outputs)
 
 
 
 
 
 
79
 
80
 
81
  demo = gr.ChatInterface(
 
2
 
3
  import os
4
  from threading import Thread
5
+ from queue import Queue, Empty
6
  from typing import Iterator
7
 
8
  import gradio as gr
9
  import spaces
10
  import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
  DESCRIPTION = "# Sakaltum-7B-chat"
14
+ DESCRIPTION += "\n<p>現在の環境に合わせて最適化されています。</p>"
 
 
15
 
16
  MAX_MAX_NEW_TOKENS = 2048
17
  DEFAULT_MAX_NEW_TOKENS = 1024
 
21
  if torch.cuda.is_available():
22
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
23
  else:
24
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
25
  model.eval()
26
 
27
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
55
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
56
  input_ids = input_ids.to(model.device)
57
 
58
+ output_queue = Queue()
59
+ def inference():
60
+ outputs = model.generate(
61
+ input_ids=input_ids,
62
+ max_new_tokens=max_new_tokens,
63
+ do_sample=True,
64
+ top_p=top_p,
65
+ top_k=top_k,
66
+ temperature=temperature,
67
+ repetition_penalty=repetition_penalty,
68
+ pad_token_id=tokenizer.eos_token_id,
69
+ )
70
+ for token in tokenizer.decode(outputs[0], skip_special_tokens=True).split():
71
+ output_queue.put(token)
72
+ output_queue.put(None) # 終了シグナル
73
+
74
+ Thread(target=inference).start()
75
 
76
  outputs = []
77
+ while True:
78
+ try:
79
+ token = output_queue.get(timeout=20.0) # タイムアウト設定
80
+ if token is None:
81
+ break
82
+ outputs.append(token)
83
+ yield "".join(outputs)
84
+ except Empty:
85
+ yield "現在応答を生成中です。しばらくお待ちください。"
86
 
87
 
88
  demo = gr.ChatInterface(