YANGSongsong committed on
Commit
d429ca3
·
verified ·
1 Parent(s): 89b50c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -19
app.py CHANGED
@@ -1,19 +1,117 @@
1
- import gradio as gr
2
- import pandas as pd
3
- from ultralytics import YOLO
4
- from skimage import data
5
- from PIL import Image
6
-
7
model = YOLO('yolov8n-cls.pt')

def predict(img):
    """Classify a PIL image with YOLOv8n-cls and return {class_name: probability}.

    The dict is built in descending-probability order, which is what
    gr.Label(num_top_classes=...) expects.
    """
    result = model.predict(source=img)
    # result[0].names is {class_index: class_name}
    df = pd.Series(result[0].names).to_frame()
    df.columns = ['names']
    # Bugfix: in ultralytics >= 8.0.136, Results.probs is a Probs object,
    # not a raw tensor; .data.tolist() yields plain floats that pandas can
    # sort and gr.Label can render.
    df['probs'] = result[0].probs.data.tolist()
    df = df.sort_values('probs', ascending=False)
    res = dict(zip(df['names'], df['probs']))
    return res
16
# Shut down any interfaces left over from a previous run before launching.
gr.close_all()

# Image in, top-5 class labels out, with a few bundled example images.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=5),
    examples=['cat.jpeg', 'people.jpeg', 'coffee.jpeg'],
)
demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# Model/tokenizer locations are overridable via environment variables;
# the tokenizer falls back to the model path when not set separately.
MODEL_PATH = os.environ.get("MODEL_PATH", "THUDM/chatglm3-6b")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)

# .float() keeps the model in fp32 — presumably a CPU deployment; confirm
# before moving to GPU, where .half().cuda() would be the usual choice.
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
13
+
14
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the newest token is an end-of-sequence id."""

    # Token ids that terminate generation for this tokenizer.
    _STOP_IDS = (0, 2)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recent token of the first (and only) sequence matters.
        last_token = int(input_ids[0][-1])
        return last_token in self._STOP_IDS
21
+
22
def parse_text(text):
    """Convert model output into chatbot-safe HTML.

    Fenced code blocks (```lang ... ```) become <pre><code> sections, and the
    lines inside a fence are escaped character-by-character so the chat widget
    renders them literally. Empty lines are dropped; the surviving lines are
    joined with <br> separators.
    """
    lines = [line for line in text.split("\n") if line != ""]
    count = 0  # number of ``` fences seen so far; odd => inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: whatever follows the backticks is the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code fence: escape characters the widget would
                    # otherwise interpret. '&' must be escaped FIRST, or the
                    # entities generated below would themselves be mangled
                    # (bugfix: the original skipped '&' entirely).
                    line = line.replace("&", "&amp;")
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    return "".join(lines)
52
+
53
def predict(history, max_length, top_p, temperature):
    """Stream a ChatGLM3 reply for the latest user turn.

    history is a list of [user_msg, model_msg] pairs; the last pair holds the
    pending user message (its model_msg is still empty). The updated history
    is yielded after every streamed token so Gradio live-updates the chatbot.
    """
    stop = StopOnTokens()

    # Rebuild the chat-template message list from the (user, assistant) pairs.
    messages = []
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == last_idx and not model_msg:
            # Pending turn: only the user side exists yet.
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)

    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(next(model.parameters()).device)

    # Generation runs on a background thread and feeds the streamer; this
    # generator drains it token by token.
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs,
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([stop]),
        repetition_penalty=1.2,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    for new_token in streamer:
        if new_token != '':
            history[-1][1] += new_token
            yield history
88
+
89
+
90
with gr.Blocks() as demo:
    # Bugfix: the heading was garbled ("ChatGLGradio Simple Demo") — restore
    # the intended title.
    gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        """Record the (HTML-escaped) user query in history and clear the textbox."""
        return "", history + [[parse_text(query), ""]]


    # Submit: first append the user turn (not queued, so it is instant),
    # then stream the model reply into the chatbot.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

demo.queue()
# NOTE(review): 127.0.0.1 serves only local requests; a hosted Space would
# need server_name="0.0.0.0" — confirm the deployment target.
demo.launch(server_name="127.0.0.1", server_port=8501, inbrowser=True, share=False)