nikravan commited on
Commit
86d1d69
·
verified ·
1 Parent(s): ee4e8c2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import gradio as gr
4
+ import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ import os
7
+ from threading import Thread
8
+
9
+ import pymupdf
10
+ import docx
11
+ from pptx import Presentation
12
+
13
+
14
+ MODEL_LIST = ["nikravan/glm_4vq"]
15
+ #MODEL_LIST = ["../Model_4b_sharded"]
16
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
17
+ MODEL_ID = MODEL_LIST[0]
18
+ MODEL_NAME = "GLM-4vq"
19
+
20
+ TITLE = "<h1>3ML-bot</h1>"
21
+
22
+ DESCRIPTION = f"""
23
+ <center>
24
+ <p>😊 A Multi-Modal Multi-Lingual(3ML) Chat.
25
+ <br>
26
+ 🚀 MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a>
27
+ </center>"""
28
+
29
+ CSS = """
30
+ h1 {
31
+ text-align: center;
32
+ display: block;
33
+ }
34
+ """
35
+
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ MODEL_ID,
38
+ torch_dtype=torch.bfloat16,
39
+ low_cpu_mem_usage=True,
40
+ trust_remote_code=True
41
+ )
42
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
43
+ model.eval()
44
+
45
+
46
+ def extract_text(path):
47
+ return open(path, 'r').read()
48
+
49
+
50
+ def extract_pdf(path):
51
+ doc = pymupdf.open(path)
52
+ text = ""
53
+ for page in doc:
54
+ text += page.get_text()
55
+ return text
56
+
57
+
58
+ def extract_docx(path):
59
+ doc = docx.Document(path)
60
+ data = []
61
+ for paragraph in doc.paragraphs:
62
+ data.append(paragraph.text)
63
+ content = '\n\n'.join(data)
64
+ return content
65
+
66
+
67
+ def extract_pptx(path):
68
+ prs = Presentation(path)
69
+ text = ""
70
+ for slide in prs.slides:
71
+ for shape in slide.shapes:
72
+ if hasattr(shape, "text"):
73
+ text += shape.text + "\n"
74
+ return text
75
+
76
+
77
+ def mode_load(path):
78
+ choice = ""
79
+ file_type = path.split(".")[-1]
80
+ print(file_type)
81
+ if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
82
+ if file_type.endswith("pdf"):
83
+ content = extract_pdf(path)
84
+ elif file_type.endswith("docx"):
85
+ content = extract_docx(path)
86
+ elif file_type.endswith("pptx"):
87
+ content = extract_pptx(path)
88
+ else:
89
+ content = extract_text(path)
90
+ choice = "doc"
91
+ print(content[:100])
92
+ return choice, content[:5000]
93
+
94
+
95
+ elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
96
+ content = Image.open(path).convert('RGB')
97
+ choice = "image"
98
+ return choice, content
99
+
100
+ else:
101
+ raise gr.Error("Oops, unsupported files.")
102
+
103
+
104
+ @spaces.GPU()
105
+ def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
106
+ print(f'message is - {message}')
107
+ print(f'history is - {history}')
108
+ conversation = []
109
+ prompt_files = []
110
+ if message["files"]:
111
+ choice, contents = mode_load(message["files"][-1])
112
+ if choice == "image":
113
+ conversation.append({"role": "user", "image": contents, "content": message['text']})
114
+ elif choice == "doc":
115
+ format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
116
+ conversation.append({"role": "user", "content": format_msg})
117
+ else:
118
+ if len(history) == 0:
119
+ # raise gr.Error("Please upload an image first.")
120
+ contents = None
121
+ conversation.append({"role": "user", "content": message['text']})
122
+ else:
123
+ # image = Image.open(history[0][0][0])
124
+ for prompt, answer in history:
125
+ if answer is None:
126
+ prompt_files.append(prompt[0])
127
+ conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
128
+ else:
129
+ conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
130
+ if len(prompt_files) > 0:
131
+ choice, contents = mode_load(prompt_files[-1])
132
+ else:
133
+ choice = ""
134
+ conversation.append({"role": "user", "image": "", "content": message['text']})
135
+
136
+
137
+ if choice == "image":
138
+ conversation.append({"role": "user", "image": contents, "content": message['text']})
139
+ elif choice == "doc":
140
+ format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
141
+ conversation.append({"role": "user", "content": format_msg})
142
+ print(f"Conversation is -\n{conversation}")
143
+
144
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
145
+ return_tensors="pt", return_dict=True).to(model.device)
146
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
147
+
148
+ generate_kwargs = dict(
149
+ max_length=max_length,
150
+ streamer=streamer,
151
+ do_sample=True,
152
+ top_p=top_p,
153
+ top_k=top_k,
154
+ temperature=temperature,
155
+ repetition_penalty=penalty,
156
+ eos_token_id=[151329, 151336, 151338],
157
+ )
158
+ gen_kwargs = {**input_ids, **generate_kwargs}
159
+
160
+ with torch.no_grad():
161
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
162
+ thread.start()
163
+ buffer = ""
164
+ for new_text in streamer:
165
+ buffer += new_text
166
+ yield buffer
167
+
168
+
169
+ chatbot = gr.Chatbot(
170
+ #rtl=True,
171
+ )
172
+ chat_input = gr.MultimodalTextbox(
173
+ interactive=True,
174
+ placeholder="Enter message or upload a file ...",
175
+ show_label=False,
176
+ #rtl=True,
177
+
178
+
179
+
180
+ )
181
+ EXAMPLES = [
182
+ [{"text": "Write a poem about spring season in French Language", }],
183
+ [{"text": "what does this chart mean?", "files": ["sales.png"]}],
184
+ [{"text": "¿Qué está escrito a mano en esta foto?", "files": ["receipt1.png"]}],
185
+ [{"text": "در مورد این عکس توضیح بده و بگو این چه فصلی می تواند باشد", "files": ["nature.jpg"]}]
186
+ ]
187
+
188
+ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
189
+ gr.HTML(TITLE)
190
+ gr.HTML(DESCRIPTION)
191
+ gr.ChatInterface(
192
+ fn=stream_chat,
193
+ multimodal=True,
194
+
195
+
196
+ textbox=chat_input,
197
+ chatbot=chatbot,
198
+ fill_height=True,
199
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
200
+ additional_inputs=[
201
+ gr.Slider(
202
+ minimum=0,
203
+ maximum=1,
204
+ step=0.1,
205
+ value=0.8,
206
+ label="Temperature",
207
+ render=False,
208
+ ),
209
+ gr.Slider(
210
+ minimum=1024,
211
+ maximum=8192,
212
+ step=1,
213
+ value=4096,
214
+ label="Max Length",
215
+ render=False,
216
+ ),
217
+ gr.Slider(
218
+ minimum=0.0,
219
+ maximum=1.0,
220
+ step=0.1,
221
+ value=1.0,
222
+ label="top_p",
223
+ render=False,
224
+ ),
225
+ gr.Slider(
226
+ minimum=1,
227
+ maximum=20,
228
+ step=1,
229
+ value=10,
230
+ label="top_k",
231
+ render=False,
232
+ ),
233
+ gr.Slider(
234
+ minimum=0.0,
235
+ maximum=2.0,
236
+ step=0.1,
237
+ value=1.0,
238
+ label="Repetition penalty",
239
+ render=False,
240
+ ),
241
+ ],
242
+ ),
243
+ gr.Examples(EXAMPLES, [chat_input])
244
+
245
+ if __name__ == "__main__":
246
+ demo.queue(api_open=False).launch(show_api=False, share=True, )#server_name="0.0.0.0", )