LukasHug committed
Commit bd3c47a (verified) · 1 Parent(s): b4cff2b

Update app.py

Files changed (1):
  app.py +23 -13
app.py CHANGED
@@ -11,10 +11,10 @@ import gradio as gr
 import torch
 from PIL import Image
 from transformers import (
-    AutoModelForCausalLM,
-    AutoProcessor,
+    AutoProcessor,
     AutoTokenizer,
-    Qwen2_5_VLForConditionalGeneration
+    Qwen2_5_VLForConditionalGeneration,
+    LlavaOnevisionForConditionalGeneration
 )
 
 from taxonomy import policy_v1
@@ -113,7 +113,7 @@ default_conversation = Conversation()
 tokenizer = None
 model = None
 processor = None
-context_len = 2048
+context_len = 8048
 
 # Helper functions
 def clear_conv(conv):
@@ -150,7 +150,7 @@ def load_model(model_path):
 
     # Otherwise assume it's a LlavaGuard model
     else:
-        model = AutoModelForCausalLM.from_pretrained(
+        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
             model_path,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
             device_map="auto" if torch.cuda.is_available() else None,
@@ -159,7 +159,7 @@ def load_model(model_path):
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
 
-    context_len = getattr(model.config, "max_position_embeddings", 2048)
+    context_len = getattr(model.config, "max_position_embeddings", 8048)
     logger.info(f"Model {model_path} loaded successfully")
     return True
 
@@ -169,10 +169,10 @@ def load_model(model_path):
 
 def get_model_list():
     models = [
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
         'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
+        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
+        'AIML-TUDA/QwenGuard-v1.2-3B',
         'AIML-TUDA/QwenGuard-v1.2-7B',
-        'AIML-TUDA/QwenGuard-v1.2-3B'
     ]
     return models
 
@@ -238,12 +238,22 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
 
     # Otherwise assume it's a LlavaGuard model
     else:
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt},
+                ],
+            },
+        ]
+
+        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+
         # Process input for LlavaGuard models
-        inputs = processor(
-            prompt,
-            images=image,
-            return_tensors="pt"
-        )
+        inputs = processor(text=text_prompt, images=image, return_tensors="pt")
+
 
         # Move to GPU if available
         if torch.cuda.is_available():
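
For reference, the LlavaGuard branch of run_inference after this commit follows the standard transformers LLaVA-OneVision flow. The sketch below is a minimal, self-contained approximation of that path under stated assumptions: the example image and prompt, and the generate/decode calls, are illustrative and not part of the diff above; the checkpoint name is taken from get_model_list().

# Minimal sketch of the post-commit LlavaGuard path (assumptions flagged in comments).
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_path = "AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf"  # one of the checkpoints from get_model_list()

# Mirrors load_model(): half precision and device_map only when a GPU is present.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

image = Image.open("example.jpg")  # hypothetical input; app.py receives the image from the Gradio UI
prompt = "Assess the image against the safety policy."  # in app.py the prompt is built from taxonomy.policy_v1

# Same chat-template construction as the run_inference hunk above.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    },
]
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=text_prompt, images=image, return_tensors="pt")

# Move inputs to the GPU and match the model's float16 dtype when available.
if torch.cuda.is_available():
    inputs = inputs.to("cuda", torch.float16)

# Sampling settings mirror run_inference's defaults (temperature=0.2, top_p=0.95, max_tokens=512);
# the generate/decode calls themselves are an assumption, as the diff does not show them.
output_ids = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.2,
    top_p=0.95,
    max_new_tokens=512,
)
generated = output_ids[0][inputs["input_ids"].shape[1]:]  # drop the echoed prompt tokens
print(processor.decode(generated, skip_special_tokens=True))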