Walmart-the-bag committed on
Commit
2bbf1ec
Β·
verified Β·
1 Parent(s): db22cf6

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -0
main.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Instruction-tuned Phi-3 mini with a 128k context window.
model_name = "microsoft/Phi-3-mini-128k-instruct"

# Load tokenizer and model once at module import (startup).
# float16 halves the memory footprint vs float32; device_map='cuda' already
# places all weights on the GPU, so the original's extra `model.to('cuda:0')`
# call was redundant and has been removed.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cuda",
)
15
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: halt generation when the newest token is a stop id."""

    # Token ids that end generation for this model/tokenizer pair.
    # NOTE(review): 29 and 0 look tokenizer-specific — confirm against the
    # Phi-3 tokenizer's special-token ids.
    _STOP_IDS = (29, 0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recently generated token of the first sequence matters.
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in self._STOP_IDS)
22
# BUG FIX: the original applied `@spaces.GPU(duration=180)` but never imported
# `spaces`, so the script crashed with NameError outside a Hugging Face Space.
# Import it when available and fall back to a no-op decorator locally.
try:
    import spaces  # Hugging Face Spaces ZeroGPU helper (only present on Spaces)

    _gpu_decorator = spaces.GPU(duration=180)
except ImportError:
    def _gpu_decorator(fn):
        return fn


@_gpu_decorator
def predict(message, history):
    """Stream an assistant reply for `message`, given the prior chat `history`.

    Parameters
    ----------
    message : str
        The user's newest message.
    history : list
        Gradio tuple-format history: pairs of [user_text, assistant_text].

    Yields
    ------
    str
        The accumulated partial reply, re-yielded as each new token arrives,
        which is how Gradio's ChatInterface renders streaming output.
    """
    # Append the new user turn with an empty assistant slot.
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Build the prompt string. NOTE(review): every turn — including the very
    # first — is prefixed with '<|end|>', reproducing the original behavior;
    # confirm this matches the intended Phi-3 chat template.
    messages = "".join(
        "".join(["<|end|>\n<|user|>\n" + item[0], "<|end|>\n<|assistant|>\n" + item[1]])
        for item in history_transformer_format
    )

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    # skip_prompt avoids echoing the input; timeout guards against a hung
    # generation thread blocking the iterator forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        top_p=0.9,
        top_k=40,
        temperature=0.9,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation on a worker thread so this generator can stream tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        # '<' begins special markers such as '<|end|>'; drop those tokens.
        if new_token != '<':
            partial_message += new_token
            yield partial_message
47
+
48
+
49
# Wire the streaming predict() generator into a chat UI; fill_height lets the
# chat pane expand to the full browser window.
demo = gr.ChatInterface(
    fn=predict,
    examples=["What is life?"],
    title="AI",
    fill_height=True,
)

# Hide the auto-generated API docs page; serve the UI only.
demo.launch(show_api=False)