Update app.py
app.py CHANGED
@@ -1,54 +1,73 @@
+import os
 import torch
 import gradio as gr
-…
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for user_msg, bot_reply in history:
-        messages.append({"role": "user", "content": user_msg})
-        if bot_reply:
-            messages.append({"role": "assistant", "content": bot_reply})
-
-    messages.append({"role": "user", "content": message})
-
-    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    model_inputs = tokenizer([text], return_tensors="pt").to("cuda")
-
-    generated_ids = model.generate(
-        **model_inputs,
-        # max_new_tokens=max_tokens,
-        # temperature=temperature,
-        # top_p=top_p,
-    )
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return response
-
-# Load model and tokenizer
+import tempfile
+import secrets
+from pathlib import Path
+from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor
+from PIL import Image
+
+# Load Vision-Language Model
+vl_model_name = "Salesforce/blip-image-captioning-large"
+vl_model = BlipForConditionalGeneration.from_pretrained(vl_model_name)
+vl_processor = AutoProcessor.from_pretrained(vl_model_name)
+
+# Load Text Model
+model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model,
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+math_messages = []
+
+def process_image(image, shouldConvert=False):
+    global math_messages
+    math_messages = []  # Reset when uploading an image
+
+    if shouldConvert:
+        new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
+        new_img.paste(image, (0, 0), mask=image)
+        image = new_img
+
+    # Convert the image to tensor
+    inputs = vl_processor(images=image, return_tensors="pt")
+    output = vl_model.generate(**inputs)
+    description = vl_processor.batch_decode(output, skip_special_tokens=True)[0]
+
+    return f"Math-related content detected: {description}"
+
+def get_math_response(image_description, user_question):
+    global math_messages
+    if not math_messages:
+        math_messages.append({'role': 'system', 'content': 'You are a helpful math assistant.'})
+    math_messages = math_messages[:1]
+    content = f'Image description: {image_description}\n\n' if image_description else ''
+    query = f"{content}User question: {user_question}"
+    math_messages.append({'role': 'user', 'content': query})
+    model_inputs = tokenizer(query, return_tensors="pt").to(device)
+    output = model.generate(**model_inputs, max_new_tokens=512)
+    answer = tokenizer.decode(output[0], skip_special_tokens=True)
+    yield answer.replace("\\", "\\\\")
+    math_messages.append({'role': 'assistant', 'content': answer})
+
+def math_chat_bot(image, sketchpad, question, state):
+    current_tab_index = state["tab_index"]
+    image_description = None
+    if current_tab_index == 0:
+        if image is not None:
+            image_description = process_image(image)
+    elif current_tab_index == 1:
+        if sketchpad and sketchpad["composite"]:
+            image_description = process_image(sketchpad["composite"], True)
+    yield from get_math_response(image_description, question)
 
 demo = gr.ChatInterface(
-    …
+    math_chat_bot,
     additional_inputs=[
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        …
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
 )
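
Several removed lines did not survive the page capture and are marked with `…` above: the old file's imports and the opening of its chat handler, the handler argument passed to gr.ChatInterface, and the old additional input components. Judging from the surviving body (which reads message, history, and system_message, and comments out max_tokens, temperature, and top_p in generate()), a plausible reconstruction of the truncated opening is the following, with every name inferred rather than taken from the commit:

# Plausible reconstruction of the truncated removed lines (hypothetical; only
# temperature and top_p survived in the capture).
from transformers import AutoModelForCausalLM, AutoTokenizer

def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    ...  # body as shown in the removed lines of the diff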
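
One wiring note on the new version: gr.ChatInterface invokes its function as fn(message, history, *additional_inputs), so with the four inputs configured above the handler receives six arguments, while math_chat_bot(image, sketchpad, question, state) accepts four and expects a tabbed layout's state dict; the "System message" Textbox is also never consulted, since get_math_response hardcodes its own system prompt. A minimal sketch of a handler that does match this ChatInterface wiring (the name chat_fn is hypothetical, and it forwards only the text question):

# Sketch (assumption, not the committed wiring): a signature matching what
# gr.ChatInterface passes with the additional_inputs configured above.
def chat_fn(message, history, system_message, max_new_tokens, temperature, top_p):
    # No image or sketchpad reaches a ChatInterface handler, so no description.
    yield from get_math_response(None, message)

demo = gr.ChatInterface(
    chat_fn,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

Feeding images and sketches through tabs, as math_chat_bot expects, would instead call for a gr.Blocks layout with explicit image and sketchpad components and a state object carrying tab_index.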
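
Also worth noting: the new get_math_response tokenizes the raw query string and decodes the full output sequence, so the prompt text is echoed back in the answer, and the chat template that the removed code applied is skipped even though Qwen2.5-Math-1.5B-Instruct is an instruct-tuned model. A sketch of the generate step using the template, under the same model/tokenizer globals:

# Sketch: apply the chat template (as the removed code did) and strip the
# prompt tokens from the decoded answer.
text = tokenizer.apply_chat_template(math_messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
output = model.generate(**model_inputs, max_new_tokens=512)
answer = tokenizer.decode(output[0][model_inputs.input_ids.shape[1]:], skip_special_tokens=True)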
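
Finally, get_math_response yields a single complete answer only after generation has finished. If incremental output is wanted in the chat window, transformers' TextIteratorStreamer can stream partial text instead; a sketch against the same globals (the helper name stream_answer is hypothetical):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_answer(query, max_new_tokens=512):
    # Run generate() in a worker thread; yield the growing answer as tokens arrive.
    inputs = tokenizer(query, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
    ).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial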