Dileep7729 commited on
Commit
988ceee
·
verified ·
1 Parent(s): f3de939

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -19
app.py CHANGED
@@ -2,20 +2,30 @@ import gradio as gr
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
4
 
5
- # Load the model and processor
6
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
7
- print("Loading the model and processor...")
8
 
9
  try:
10
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
11
  processor = CLIPProcessor.from_pretrained(model_name)
12
  print("Model and processor loaded successfully.")
13
  except Exception as e:
14
- print(f"Error loading the model or processor: {e}")
15
- raise RuntimeError(f"Failed to load model: {e}")
16
 
17
- # Define the inference function
 
18
  def classify_image(image):
 
 
 
 
 
 
 
 
 
19
  try:
20
  print("Starting image classification...")
21
 
@@ -23,52 +33,56 @@ def classify_image(image):
23
  if image is None:
24
  raise ValueError("No image provided. Please upload a valid image.")
25
  if not hasattr(image, "convert"):
26
- raise ValueError("Uploaded file is not a valid image format.")
27
 
28
- # Define categories
29
  categories = ["safe", "unsafe"]
30
- print(f"Categories: {categories}")
31
 
32
- # Process the image
33
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
34
- print(f"Processed inputs: {inputs}")
35
 
36
  # Perform inference
37
  outputs = model(**inputs)
38
- print(f"Model outputs: {outputs}")
39
 
40
  # Calculate probabilities
41
  logits_per_image = outputs.logits_per_image
42
  probs = logits_per_image.softmax(dim=1)
43
  print(f"Probabilities: {probs}")
44
 
45
- # Extract probabilities
46
  safe_prob = probs[0][0].item() * 100
47
  unsafe_prob = probs[0][1].item() * 100
48
  print(f"Safe: {safe_prob:.2f}%, Unsafe: {unsafe_prob:.2f}%")
49
 
50
  # Determine the predicted category
51
  predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
 
 
 
52
  return predicted_category, {"safe": f"{safe_prob:.2f}%", "unsafe": f"{unsafe_prob:.2f}%"}
53
 
54
  except Exception as e:
55
  print(f"Error during classification: {e}")
56
- return f"Error: {str(e)}", {}
 
57
 
58
- # Gradio interface
59
  iface = gr.Interface(
60
  fn=classify_image,
61
- inputs=gr.Image(type="pil"),
62
  outputs=[
63
- gr.Textbox(label="Predicted Category"),
64
- gr.Label(label="Probabilities"),
65
  ],
66
  title="Content Safety Classification",
67
- description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
68
  )
69
 
70
  if __name__ == "__main__":
71
- print("Launching Gradio interface...")
72
  iface.launch()
73
 
74
 
@@ -93,5 +107,6 @@ if __name__ == "__main__":
93
 
94
 
95
 
 
96
 
97
 
 
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
4
 
5
+ # Load Model and Processor
6
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
7
+ print("Initializing the model and processor...")
8
 
9
  try:
10
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
11
  processor = CLIPProcessor.from_pretrained(model_name)
12
  print("Model and processor loaded successfully.")
13
  except Exception as e:
14
+ print(f"Error loading model or processor: {e}")
15
+ raise RuntimeError(f"Failed to load the model: {e}")
16
 
17
+
18
+ # Inference Function
19
  def classify_image(image):
20
+ """
21
+ Classifies an image as 'safe' or 'unsafe' using the CLIP model.
22
+
23
+ Args:
24
+ image (PIL.Image.Image): Uploaded image.
25
+
26
+ Returns:
27
+ Tuple: Predicted category and probabilities for "safe" and "unsafe".
28
+ """
29
  try:
30
  print("Starting image classification...")
31
 
 
33
  if image is None:
34
  raise ValueError("No image provided. Please upload a valid image.")
35
  if not hasattr(image, "convert"):
36
+ raise ValueError("Uploaded file is not a valid image format (JPEG, PNG, etc.).")
37
 
38
+ # Define classification categories
39
  categories = ["safe", "unsafe"]
40
+ print(f"Using categories: {categories}")
41
 
42
+ # Process image using CLIPProcessor
43
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
44
+ print("Image processed successfully.")
45
 
46
  # Perform inference
47
  outputs = model(**inputs)
48
+ print("Inference completed successfully.")
49
 
50
  # Calculate probabilities
51
  logits_per_image = outputs.logits_per_image
52
  probs = logits_per_image.softmax(dim=1)
53
  print(f"Probabilities: {probs}")
54
 
55
+ # Extract probabilities for each category
56
  safe_prob = probs[0][0].item() * 100
57
  unsafe_prob = probs[0][1].item() * 100
58
  print(f"Safe: {safe_prob:.2f}%, Unsafe: {unsafe_prob:.2f}%")
59
 
60
  # Determine the predicted category
61
  predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
62
+ print(f"Predicted Category: {predicted_category}")
63
+
64
+ # Return category and probabilities
65
  return predicted_category, {"safe": f"{safe_prob:.2f}%", "unsafe": f"{unsafe_prob:.2f}%"}
66
 
67
  except Exception as e:
68
  print(f"Error during classification: {e}")
69
+ return "Error", {"safe": "N/A", "unsafe": "N/A"}
70
+
71
 
72
+ # Gradio Interface
73
  iface = gr.Interface(
74
  fn=classify_image,
75
+ inputs=gr.Image(type="pil"), # Accept image input
76
  outputs=[
77
+ gr.Textbox(label="Predicted Category"), # Predicted category
78
+ gr.Label(label="Probabilities"), # Probabilities as progress bars
79
  ],
80
  title="Content Safety Classification",
81
+ description="Upload an image to classify it as 'safe' or 'unsafe' with probabilities.",
82
  )
83
 
84
  if __name__ == "__main__":
85
+ print("Launching the Gradio interface...")
86
  iface.launch()
87
 
88
 
 
107
 
108
 
109
 
110
+
111
 
112