Tonic committed
Commit 83c27e6
1 Parent(s): c8704f4
Files changed (1)
  1. app.py +36 -19
app.py CHANGED
@@ -8,21 +8,17 @@ import requests
 from huggingface_hub import login
 import torch
 import torch.nn.functional as F
-import spaces
+# import spaces
 import json
-import gradio as gr
 from huggingface_hub import snapshot_download
-import os
-# from loadimg import load_img
-import traceback
+# import traceback
 
 login(os.environ.get("HUGGINGFACE_TOKEN"))
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:80"
 
 repo_id = "mistralai/Pixtral-12B-2409"
-sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
-max_tokens_per_img = 4096
-max_img_per_msg = 5
-
+max_tokens_per_img = 2048
+max_img_per_msg = 2
 
 title = "# **WIP / DEMO** 🙋🏻‍♂️Welcome to Tonic's Pixtral Model Demo"
 description = """
@@ -40,9 +36,15 @@ with open(f'{model_path}/tekken.json', 'r') as f:
 
 model_name = "mistralai/Pixtral-12B-2409"
 
-sampling_params = SamplingParams(max_tokens=8192)
+llm = LLM(
+    model=model_name,
+    tokenizer_mode="mistral",
+    max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
+    dtype="float16"
+)
 
-llm = LLM(model=model_name, tokenizer_mode="mistral")
+def clear_cuda_cache():
+    torch.cuda.empty_cache()
 
 def encode_image(image: Image.Image, image_format="PNG") -> str:
     im_file = BytesIO()
@@ -51,11 +53,13 @@ def encode_image(image: Image.Image, image_format="PNG") -> str:
     im_64 = base64.b64encode(im_bytes).decode("utf-8")
     return im_64
 
-def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
+def infer(image_url, prompt, temperature, max_tokens, progress=gr.Progress(track_tqdm=True)):
     if llm is None:
         return "Error: LLM initialization failed. Please try again later."
 
     try:
+        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)
+
         image = Image.open(BytesIO(requests.get(image_url).content))
         image = image.resize((3844, 2408))
         new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"
@@ -68,16 +72,19 @@ def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
         ]
 
         outputs = llm.chat(messages, sampling_params=sampling_params)
-
+        clear_cuda_cache()
         return outputs[0].outputs[0].text
     except Exception as e:
+        clear_cuda_cache()
         return f"Error during inference: {e}"
 
-def compare_images(image1_url, image2_url, prompt, progress=gr.Progress(track_tqdm=True)):
+def compare_images(image1_url, image2_url, prompt, temperature, max_tokens, progress=gr.Progress(track_tqdm=True)):
     if llm is None:
         return "Error: LLM initialization failed. Please try again later."
 
     try:
+        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)
+
         image1 = Image.open(BytesIO(requests.get(image1_url).content))
         image2 = Image.open(BytesIO(requests.get(image2_url).content))
         image1 = image1.resize((3844, 2408))
@@ -97,9 +104,10 @@ def compare_images(image1_url, image2_url, prompt, progress=gr.Progress(track_tq
         ]
 
         outputs = llm.chat(messages, sampling_params=sampling_params)
-
+        clear_cuda_cache()
        return outputs[0].outputs[0].text
     except Exception as e:
+        clear_cuda_cache()
         return f"Error during image comparison: {e}"
 
 def calculate_image_similarity(image1_url, image2_url):
@@ -120,9 +128,10 @@ def calculate_image_similarity(image1_url, image2_url):
         embedding2 = llm.model.vision_encoder([image2_tensor])
 
         similarity = F.cosine_similarity(embedding1.mean(dim=0), embedding2.mean(dim=0), dim=0).item()
-
+        clear_cuda_cache()
         return similarity
     except Exception as e:
+        clear_cuda_cache()
         return f"Error during image similarity calculation: {e}"
 
 with gr.Blocks() as demo:
@@ -137,10 +146,12 @@ with gr.Blocks() as demo:
     1. For Image-to-Text Generation:
        - Enter the URL of an image
        - Provide a prompt describing what you want to know about the image
+       - Adjust the temperature and max tokens
       - Click "Generate" to get the model's response
     2. For Image Comparison:
        - Enter URLs for two images you want to compare
        - Provide a prompt asking about the comparison
+       - Adjust the temperature and max tokens
       - Click "Compare" to get the model's analysis
     3. For Image Similarity:
        - Enter URLs for two images you want to compare
@@ -153,20 +164,26 @@ with gr.Blocks() as demo:
            with gr.Row():
                image_url = gr.Text(label="Image URL")
                prompt = gr.Text(label="Prompt")
+           with gr.Row():
+               temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature")
+               max_tokens = gr.Number(value=4096, label="Max Tokens")
            generate_button = gr.Button("Generate")
            output = gr.Text(label="Generated Text")
 
-           generate_button.click(infer, inputs=[image_url, prompt], outputs=output)
+           generate_button.click(infer, inputs=[image_url, prompt, temperature, max_tokens], outputs=output)
 
        with gr.TabItem("Image Comparison"):
            with gr.Row():
                image1_url = gr.Text(label="Image 1 URL")
                image2_url = gr.Text(label="Image 2 URL")
                comparison_prompt = gr.Text(label="Comparison Prompt")
+           with gr.Row():
+               comparison_temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature")
+               comparison_max_tokens = gr.Number(value=4096, label="Max Tokens")
            compare_button = gr.Button("Compare")
            comparison_output = gr.Text(label="Comparison Result")
 
-           compare_button.click(compare_images, inputs=[image1_url, image2_url, comparison_prompt], outputs=comparison_output)
+           compare_button.click(compare_images, inputs=[image1_url, image2_url, comparison_prompt, comparison_temperature, comparison_max_tokens], outputs=comparison_output)
 
        with gr.TabItem("Image Similarity"):
            with gr.Row():
@@ -187,4 +204,4 @@ with gr.Blocks() as demo:
    gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}")
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
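
Note on the change: the rebuilt vLLM engine now caps its batched token budget at max_num_batched_tokens = max_img_per_msg * max_tokens_per_img = 2 * 2048 = 4096, and SamplingParams is constructed per request from the new Temperature and Max Tokens controls, with the CUDA cache cleared after each call. Below is a minimal, hypothetical smoke-test sketch of the updated infer() signature, not part of this commit; it assumes the Space's app.py is importable as `app`, that HUGGINGFACE_TOKEN is set, that a CUDA GPU large enough for Pixtral-12B in float16 is available, and that the image URL is a placeholder.

# Hypothetical smoke test for the new infer(image_url, prompt, temperature, max_tokens, ...) signature.
# Importing app triggers the login() call and builds the vLLM engine, so this is heavy to run.
from app import infer

text = infer(
    image_url="https://example.com/sample.png",  # placeholder URL (assumption, not from the commit)
    prompt="Describe this image.",
    temperature=0.7,   # default of the new Temperature slider
    max_tokens=4096,   # default of the new Max Tokens field
)
print(text)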