prithivMLmods committed (verified)
Commit f3b1002 · 1 Parent(s): 2b5ad33

Update app.py

Files changed (1):
  1. app.py +283 -345
app.py CHANGED
@@ -1,364 +1,302 @@
- import gradio as gr
  import spaces
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
- from qwen_vl_utils import process_vision_info
- import torch
- from PIL import Image
  import os
- import uuid
- import io
  from threading import Thread
- from reportlab.lib.pagesizes import A4
- from reportlab.lib.styles import getSampleStyleSheet
- from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
- from reportlab.lib.units import inch
- from reportlab.pdfbase import pdfmetrics
- from reportlab.pdfbase.ttfonts import TTFont
- import docx
- from docx.enum.text import WD_ALIGN_PARAGRAPH
-
- # Define model options
- MODEL_OPTIONS = {
-     "Qwen2VL Base": "Qwen/Qwen2-VL-2B-Instruct",
-     "Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
-     "Math Prase": "prithivMLmods/Qwen2-VL-Math-Prase-2B-Instruct",
-     "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
- }

- # Preload models and processors into CUDA
- models = {}
- processors = {}
- for name, model_id in MODEL_OPTIONS.items():
-     print(f"Loading {name}...")
-     models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
-         model_id,
-         trust_remote_code=True,
-         torch_dtype=torch.float16
-     ).to("cuda").eval()
-     processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-
- # Get valid image extensions from PIL
- image_extensions = Image.registered_extensions()
-
- def identify_and_save_blob(blob_path):
-     """Identifies if the blob is an image and saves it."""
-     try:
-         with open(blob_path, 'rb') as file:
-             blob_content = file.read()
-             try:
-                 Image.open(io.BytesIO(blob_content)).verify() # Check if it's a valid image
-                 extension = ".png" # Default to PNG for saving
-                 media_type = "image"
-             except (IOError, SyntaxError):
-                 raise ValueError("Unsupported media type. Please upload a valid image.")
-
-             filename = f"temp_{uuid.uuid4()}_media{extension}"
-             with open(filename, "wb") as f:
-                 f.write(blob_content)
-
-             return filename, media_type
-
-     except FileNotFoundError:
-         raise ValueError(f"The file {blob_path} was not found.")
-     except Exception as e:
-         raise ValueError(f"An error occurred while processing the file: {e}")
-
- def get_media_file(media_input):
      """
-     Ensures that the media input is a file path.
-     If it is a PIL image, it saves it temporarily and returns the file path.
      """
-     if isinstance(media_input, str):
-         return media_input # Already a file path
      else:
-         if not isinstance(media_input, Image.Image):
-             # Convert numpy array to PIL image if needed
-             media_input = Image.fromarray(media_input)
-         temp_filename = f"temp_{uuid.uuid4()}.png"
-         media_input.save(temp_filename)
-         return temp_filename

  @spaces.GPU
- def qwen_inference(model_name, media_input, text_input=None):
-     """Handles inference for the selected model."""
-     model = models[model_name]
-     processor = processors[model_name]
-
-     # Determine media type and obtain a file path if needed
-     if isinstance(media_input, str):
-         media_path = media_input
-         if media_path.endswith(tuple(image_extensions.keys())):
-             media_type = "image"
-         else:
-             try:
-                 media_path, media_type = identify_and_save_blob(media_input)
-             except Exception as e:
-                 raise ValueError("Unsupported media type. Please upload a valid image.")
-     else:
-         # media_input is a PIL image (or numpy array) coming from gr.Image
-         media_path = get_media_file(media_input)
-         media_type = "image"
-
-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {
-                     "type": media_type,
-                     media_type: media_path
-                 },
-                 {"type": "text", "text": text_input},
-             ],
-         }
-     ]
-
-     text = processor.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
-     )
-     image_inputs, _ = process_vision_info(messages)
-     inputs = processor(
-         text=[text],
-         images=image_inputs,
-         padding=True,
-         return_tensors="pt",
-     ).to("cuda")
-
-     streamer = TextIteratorStreamer(
-         processor.tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         # Remove <|im_end|> or similar tokens from the output
-         buffer = buffer.replace("<|im_end|>", "")
-         yield buffer
-
- def format_plain_text(output_text):
-     """Formats the output text as plain text without LaTeX delimiters."""
-     plain_text = output_text.replace("\\(", "").replace("\\)", "").replace("\\[", "").replace("\\]", "")
-     return plain_text
-
- def generate_document(media_input, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a document with the input image and plain text output."""
-     # Ensure media_input is a file path.
-     media_path = get_media_file(media_input)
-     plain_text = format_plain_text(output_text)
-     if file_format == "pdf":
-         return generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-     elif file_format == "docx":
-         return generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-
- def generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a PDF document."""
-     filename = f"output_{uuid.uuid4()}.pdf"
-     doc = SimpleDocTemplate(
-         filename,
-         pagesize=A4,
-         rightMargin=inch,
-         leftMargin=inch,
-         topMargin=inch,
-         bottomMargin=inch
-     )
-     styles = getSampleStyleSheet()
-     styles["Normal"].fontName = font_choice
-     styles["Normal"].fontSize = int(font_size)
-     styles["Normal"].leading = int(font_size) * line_spacing
-     styles["Normal"].alignment = {
-         "Left": 0,
-         "Center": 1,
-         "Right": 2,
-         "Justified": 4
-     }[alignment]
-
-     # Register font (assumes font files are available in a folder named "font")
-     font_path = f"font/{font_choice}"
-     pdfmetrics.registerFont(TTFont(font_choice, font_path))
-
-     story = []
-
-     # Add image with size adjustment
-     image_sizes = {
-         "Small": (200, 200),
-         "Medium": (400, 400),
-         "Large": (600, 600)
      }
-     img = RLImage(media_path, width=image_sizes[image_size][0], height=image_sizes[image_size][1])
-     story.append(img)
-     story.append(Spacer(1, 12))
-
-     # Add plain text output
-     text_para = Paragraph(plain_text, styles["Normal"])
-     story.append(text_para)
-
-     doc.build(story)
-     return filename
-
- def generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a DOCX document."""
-     filename = f"output_{uuid.uuid4()}.docx"
-     doc = docx.Document()
-
-     # Add image with size adjustment
-     image_sizes = {
-         "Small": docx.shared.Inches(2),
-         "Medium": docx.shared.Inches(4),
-         "Large": docx.shared.Inches(6)
      }
-     doc.add_picture(media_path, width=image_sizes[image_size])
-     doc.add_paragraph()
-
-     # Add plain text output
-     paragraph = doc.add_paragraph()
-     paragraph.paragraph_format.line_spacing = line_spacing
-     paragraph.paragraph_format.alignment = {
-         "Left": WD_ALIGN_PARAGRAPH.LEFT,
-         "Center": WD_ALIGN_PARAGRAPH.CENTER,
-         "Right": WD_ALIGN_PARAGRAPH.RIGHT,
-         "Justified": WD_ALIGN_PARAGRAPH.JUSTIFY
-     }[alignment]
-     run = paragraph.add_run(plain_text)
-     run.font.name = font_choice
-     run.font.size = docx.shared.Pt(int(font_size))
-
-     doc.save(filename)
-     return filename
-
- # CSS for output styling
- css = """
- #output {
-     height: 400px;
-     overflow: auto;
-     border: 1px solid #ccc;
- }
- .submit-btn {
-     background-color: #cf3434 !important;
-     color: white !important;
  }
- .submit-btn:hover {
-     background-color: #ff2323 !important;
  }
- .download-btn {
-     background-color: #35a6d6 !important;
-     color: white !important;
  }
- .download-btn:hover {
-     background-color: #22bcff !important;
  }
  """

- # Gradio app setup
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Qwen2VL: Compact Vision & Language Processing")
-
-     with gr.Tab(label="Image Input"):
-         with gr.Row():
-             with gr.Column():
-                 model_choice = gr.Dropdown(
-                     label="Model Selection",
-                     choices=list(MODEL_OPTIONS.keys()),
-                     value="Latex OCR"
-                 )
-                 # Using gr.Image instead of gr.File for image upload
-                 input_media = gr.Image(
-                     label="Upload Image", type="pil"
-                 )
-                 text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
-                 submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
-             with gr.Column():
-                 output_text = gr.Textbox(label="Output Text", lines=10)
-                 plain_text_output = gr.Textbox(label="Standardized Plain Text", lines=10)
-
-         submit_btn.click(
-             qwen_inference, [model_choice, input_media, text_input], [output_text]
-         ).then(
-             lambda output_text: format_plain_text(output_text), [output_text], [plain_text_output]
-         )
-
-         # Add examples directly usable by clicking
-         with gr.Row():
-             gr.Examples(
-                 examples=[
-                     ["examples/4.png", "solve the problem", "Math Prase"],
-                     ["examples/1.png", "summarize the letter", "Text Analogy Ocrtest"],
-                     ["examples/2.jpg", "Summarize the full image in detail", "Latex OCR"],
-                     ["examples/3.png", "Describe the photo", "Qwen2VL Base"],
-                 ],
-                 inputs=[input_media, text_input, model_choice],
-                 outputs=[output_text, plain_text_output],
-                 fn=lambda img, question, model: qwen_inference(model, img, question),
-                 cache_examples=False,
              )
-         with gr.Row():
-             with gr.Column():
-                 line_spacing = gr.Dropdown(
-                     choices=[0.5, 1.0, 1.15, 1.5, 2.0, 2.5, 3.0],
-                     value=1.5,
-                     label="Line Spacing"
-                 )
-                 font_size = gr.Dropdown(
-                     choices=["8", "10", "12", "14", "16", "18", "20", "22", "24"],
-                     value="12",
-                     label="Font Size"
                  )
-                 font_choice = gr.Dropdown(
-                     choices=[
-                         "DejaVuMathTeXGyre.ttf",
-                         "FiraCode-Medium.ttf",
-                         "InputMono-Light.ttf",
-                         "JetBrainsMono-Thin.ttf",
-                         "ProggyCrossed Regular Mac.ttf",
-                         "SourceCodePro-Black.ttf",
-                         "arial.ttf",
-                         "calibri.ttf",
-                         "mukta-malar-extralight.ttf",
-                         "noto-sans-arabic-medium.ttf",
-                         "times new roman.ttf",
-                         "ANGSA.ttf",
-                         "Book-Antiqua.ttf",
-                         "CONSOLA.TTF",
-                         "COOPBL.TTF",
-                         "Rockwell-Bold.ttf",
-                         "Candara Light.TTF",
-                         "Carlito-Regular.ttf Carlito-Regular.ttf",
-                         "Castellar.ttf",
-                         "Courier New.ttf",
-                         "LSANS.TTF",
-                         "Lucida Bright Regular.ttf",
-                         "TRTempusSansITC.ttf",
-                         "Verdana.ttf",
-                         "bell-mt.ttf",
-                         "eras-itc-light.ttf",
-                         "fonnts.com-aptos-light.ttf",
-                         "georgia.ttf",
-                         "segoeuithis.ttf",
-                         "youyuan.TTF",
-                         "TfPonetoneExpanded-7BJZA.ttf",
-                     ],
-                     value="youyuan.TTF",
-                     label="Font Choice"
-                 )
-                 alignment = gr.Dropdown(
-                     choices=["Left", "Center", "Right", "Justified"],
-                     value="Justified",
-                     label="Text Alignment"
-                 )
-                 image_size = gr.Dropdown(
-                     choices=["Small", "Medium", "Large"],
-                     value="Small",
-                     label="Image Size"
-                 )
-                 file_format = gr.Radio(["pdf", "docx"], label="File Format", value="pdf")
-         with gr.Row():
-             get_document_btn = gr.Button(value="Get Document", elem_classes="download-btn")
-         get_document_btn.click(
-             generate_document,
-             [input_media, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size],
-             gr.File(label="Download Document")
-         )
-
- demo.launch(debug=True)
+ import subprocess
+ subprocess.run(
+     'pip install flash-attn==2.7.0.post2 --no-build-isolation',
+     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+     shell=True
+ )
+ subprocess.run(
+     'pip install transformers',
+     shell=True
+ )
+
+
  import spaces
  import os
+ import re
+ import logging
+ from typing import List
  from threading import Thread
+ import base64
+
+ import torch
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ # ----------------------------------------------------------------------
+ # 1. Setup Model & Tokenizer
+ # ----------------------------------------------------------------------
+ model_name = 'prithivMLmods/Raptor-X5-UIGEN' # Change as needed
+ use_thread = True # Generation happens in a background thread
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True
+ ).to("cuda")
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ----------------------------------------------------------------------
+ # 2. Two-Phase Prompt Templates
+ # ----------------------------------------------------------------------
+ s1_inference_prompt_think_only = """<|im_start|>user
+ {question}<|im_end|>
+ <|im_start|>assistant
+ <|im_start|>think
+ """

+ # ----------------------------------------------------------------------
+ # 3. Generation Parameter Setup
+ # ----------------------------------------------------------------------
+ THINK_MAX_NEW_TOKENS = 12000
+ ANSWER_MAX_NEW_TOKENS = 12000
+
+ def initialize_gen_kwargs():
+     return {
+         "max_new_tokens": 1024, # default; will be overwritten per phase
+         "do_sample": True,
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "repetition_penalty": 1.05,
+         # "eos_token_id": model.generation_config.eos_token_id, # Removed to avoid premature stopping
+         "pad_token_id": tokenizer.pad_token_id,
+         "use_cache": True,
+         "streamer": None # dynamically added
+     }
+
+ # ----------------------------------------------------------------------
+ # 4. Helper to submit chat
+ # ----------------------------------------------------------------------
+ def submit_chat(chatbot, text_input):
+     if not text_input.strip():
+         return chatbot, ""
+     response = ""
+     chatbot.append((text_input, response))
+     return chatbot, ""
+
+ # ----------------------------------------------------------------------
+ # 5. Artifacts Handling
+ #    We parse code from the final answer and display it in an iframe
+ # ----------------------------------------------------------------------
+ def extract_html_code_block(text: str) -> str:
      """
+     Look for a ```html ... ``` block in the text.
+     If found, return only that block's content.
+     Otherwise, return the entire text.
      """
+     pattern = r'```html\s*(.*?)\s*```'
+     match = re.search(pattern, text, re.DOTALL)
+     if match:
+         return match.group(1).strip()
      else:
+         return text.strip()

+ def send_to_sandbox(html_code: str) -> str:
+     """
+     Convert the code to a data-URI iframe so it can be rendered
+     inside a Gradio HTML component.
+     """
+     encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
+     data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
+     return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
+
+ # ----------------------------------------------------------------------
+ # 6. The Two-Phase Streaming Inference
+ #    - Phase 1: "think" (chain-of-thought)
+ #    - Phase 2: "answer"
+ # ----------------------------------------------------------------------
  @spaces.GPU
+ def ovis_chat(chatbot: List[List[str]]):
+     # Phase 1: chain-of-thought
+     last_query = chatbot[-1][0]
+     formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
+     input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
+     attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
+
+     think_inputs = {
+         "input_ids": input_ids_think,
+         "attention_mask": attention_mask_think
      }
+     gen_kwargs_think = initialize_gen_kwargs()
+     gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
+     think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen_kwargs_think["streamer"] = think_streamer
+
+     full_think = ""
+     with torch.inference_mode():
+         thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
+         thread_think.start()
+         for new_text in think_streamer:
+             full_think += new_text
+             display_text = f"<|im_start|>think\n{full_think.strip()}"
+             chatbot[-1][1] = display_text
+             yield chatbot, "" # second return is artifact placeholder
+     thread_think.join()
+
+     # Phase 2: answer
+     new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
+     input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
+     attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
+
+     answer_inputs = {
+         "input_ids": input_ids_answer,
+         "attention_mask": attention_mask_answer
      }
+     gen_kwargs_answer = initialize_gen_kwargs()
+     gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
+     answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen_kwargs_answer["streamer"] = answer_streamer
+
+     full_answer = ""
+     with torch.inference_mode():
+         thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
+         thread_answer.start()
+         for new_text in answer_streamer:
+             full_answer += new_text
+             display_text = (
+                 f"<|im_start|>think\n{full_think.strip()}\n\n"
+                 f"<|im_start|>answer\n{full_answer.strip()}"
+             )
+             chatbot[-1][1] = display_text
+             yield chatbot, ""
+     thread_answer.join()
+
+     log_conversation(chatbot)
+
+     # Once the final answer is complete, parse out the HTML code block and
+     # return it as an artifact (iframe).
+     html_code = extract_html_code_block(full_answer)
+     sandbox_iframe = send_to_sandbox(html_code)
+     yield chatbot, sandbox_iframe
+
+ # ----------------------------------------------------------------------
+ # 7. Logging and Clearing
+ # ----------------------------------------------------------------------
+ def log_conversation(chatbot: List[List[str]]):
+     logger.info("[CONVERSATION]")
+     for i, (query, response) in enumerate(chatbot, 1):
+         logger.info(f"Q{i}: {query}\nA{i}: {response}")
+
+ def clear_chat():
+     return [], "", ""
+
+ # ----------------------------------------------------------------------
+ # 8. Gradio UI Setup
+ # ----------------------------------------------------------------------
+ css_code = """
+ .left_header {
+     display: flex;
+     flex-direction: column;
+     justify-content: center;
+     align-items: center;
+ }
+
+ .right_panel {
+     margin-top: 16px;
+     border: 1px solid #BFBFC4;
+     border-radius: 8px;
+     overflow: hidden;
+ }
+
+ .render_header {
+     height: 30px;
+     width: 100%;
+     padding: 5px 16px;
+     background-color: #f5f5f5;
+ }
+
+ .header_btn {
+     display: inline-block;
+     height: 10px;
+     width: 10px;
+     border-radius: 50%;
+     margin-right: 4px;
+ }
+
+ .render_header > .header_btn:nth-child(1) {
+     background-color: #f5222d;
  }
+
+ .render_header > .header_btn:nth-child(2) {
+     background-color: #faad14;
+ }
+ .render_header > .header_btn:nth-child(3) {
+     background-color: #52c41a;
  }
+
+ .right_content {
+     height: 920px;
+     display: flex;
+     flex-direction: column;
+     justify-content: center;
+     align-items: center;
  }
+
+ .html_content {
+     width: 100%;
+     height: 920px;
  }
  """

+ svg_content = """
+ <svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
+ <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
+ <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
+ <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
+ <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
+ </svg>
+ """
+
+ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
+     gr.HTML(f"""
+     <div class="left_header" style="margin-bottom: 20px;">
+         {svg_content}
+         <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
+         <p>(Two-phase chain-of-thought with artifact extraction)</p>
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             chatbot = gr.Chatbot(
+                 label="Chat",
+                 height=520,
+                 show_copy_button=True
              )
+             with gr.Row():
+                 text_input = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your query...",
+                     lines=1
                  )
+             with gr.Row():
+                 submit_btn = gr.Button("Send", variant="primary")
+                 clear_btn = gr.Button("Clear", variant="secondary")
+         with gr.Column(scale=6):
+             gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
+             artifact_html = gr.HTML(
+                 value="",
+                 elem_classes="html_content"
+             )
+
+     submit_btn.click(
+         submit_chat, [chatbot, text_input], [chatbot, text_input]
+     ).then(
+         ovis_chat, [chatbot], [chatbot, artifact_html]
+     )
+
+     text_input.submit(
+         submit_chat, [chatbot, text_input], [chatbot, text_input]
+     ).then(
+         ovis_chat, [chatbot], [chatbot, artifact_html]
+     )
+
+     clear_btn.click(
+         clear_chat,
+         outputs=[chatbot, text_input, artifact_html]
+     )
+
+ demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
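
Note on the two-phase flow in the new `ovis_chat`: phase 1 generates under the `think` turn, and phase 2 re-feeds the original prompt plus the captured thought before opening an `answer` turn. A minimal, model-free sketch of that prompt splicing; the `question` and `full_think` strings are hypothetical stand-ins for real user input and streamed phase-1 output:

```python
# Two-phase prompt assembly from the new app.py, shown with stand-in
# strings instead of real model.generate() calls.
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""

question = "Build a landing page for a coffee shop."  # hypothetical user query
think_prompt = s1_inference_prompt_think_only.format(question=question)

# Stand-in for the tokens streamed during phase 1.
full_think = "The page needs a hero section, a menu grid, and a footer."

# Phase 2 re-feeds the whole think prompt plus the captured thought,
# then opens the answer turn for the model to complete.
answer_prompt = think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
print(answer_prompt)
```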
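The background-thread streaming used in both phases follows the standard `TextIteratorStreamer` pattern: `generate()` blocks, so it runs in a worker thread while the caller drains the streamer. A self-contained sketch, assuming a small CPU-friendly checkpoint (`gpt2` here is purely for illustration, not the model this Space loads):

```python
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a worker thread; the main thread iterates the
# streamer and receives decoded text piece by piece, as app.py does per phase.
thread = Thread(target=lm.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()
```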
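The artifact path can be exercised without any model at all. The helpers below mirror the committed `extract_html_code_block` / `send_to_sandbox` logic, with a hypothetical `fake_answer` standing in for a generated reply:

```python
import re
import base64

def extract_html_code_block(text: str) -> str:
    # Return the contents of the first ```html ... ``` fence, or the raw text.
    match = re.search(r'```html\s*(.*?)\s*```', text, re.DOTALL)
    return match.group(1).strip() if match else text.strip()

def send_to_sandbox(html_code: str) -> str:
    # Base64-encode the page into a data URI so the iframe needs no temp files.
    encoded = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    return f'<iframe src="data:text/html;charset=utf-8;base64,{encoded}" width="100%" height="920px"></iframe>'

# Hypothetical model reply; the helpers above mirror the committed versions.
fake_answer = "Sure, here is the page:\n```html\n<h1>Hello</h1>\n```"
print(send_to_sandbox(extract_html_code_block(fake_answer)))
```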