asderene committed
Commit 34d250f · verified · 1 Parent(s): a8985ff

Create app.py

Files changed (1)
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
+ # Gradio chat demo: Qwen1.5-14B-Chat with streamed token-by-token output.
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from transformers import is_torch_npu_available
+ from threading import Thread
+
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-14B-Chat")
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-14B-Chat", torch_dtype=torch.bfloat16)
+ # Prefer an Ascend NPU if available, otherwise fall back to the first CUDA GPU.
+ if is_torch_npu_available():
+     model.to("npu:0")
+ elif torch.cuda.is_available():
+     model.to("cuda:0")
+
+
+ class StopOnTokens(StoppingCriteria):
+     """Stop generation once the most recent token matches a stop id."""
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Hard-coded stop id kept from the original; tokenizer.eos_token_id
+         # would be the model-agnostic choice here.
+         stop_ids = [2]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ def predict(message, history):
+     # if is_torch_npu_available():
+     #     torch.npu.set_device(model.device)
+     stop = StopOnTokens()
+     conversation = []
+
+     # Rebuild the chat history as role/content messages.
+     for user, assistant in history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+
+     conversation.append({"role": "user", "content": message})
+     print(f'>>>conversation={conversation}', flush=True)
+     prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+     model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=50,
+         temperature=0.7,
+         repetition_penalty=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     # Generate on a background thread so tokens can be yielded as they stream in.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         if '</s>' in partial_message:
+             break
+         yield partial_message
+
+
+ # Set up the Gradio chat interface.
+ gr.ChatInterface(
+     predict,
+     title="Qwen1.5 14B Chat Demo",
+     description="Warning: all answers are generated and may contain inaccurate information.",
+     examples=['How do you cook fish?', 'Who is the president of the United States?'],
+ ).launch()
+
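For reference, the prompt string that predict() feeds to the model comes from the tokenizer's chat template. A minimal sketch of what that call produces for a one-turn conversation (assuming only that transformers is installed; the exact text comes from the Qwen1.5 tokenizer config):

# Minimal sketch: print the prompt apply_chat_template builds for one user turn.
# Qwen1.5 uses a ChatML-style template, so the output looks roughly like
# "<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n".
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-14B-Chat")
conversation = [{"role": "user", "content": "How do you cook fish?"}]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
print(prompt)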