Dileep7729 commited on
Commit
e99084c
·
verified ·
1 Parent(s): 4d41f6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -56
app.py CHANGED
@@ -1,82 +1,39 @@
1
  import gradio as gr
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
4
- import requests
5
 
6
- # Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
7
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
8
-
9
  print("Initializing the application...")
10
 
11
  try:
12
- print("Loading the model from Hugging Face Model Hub...")
13
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
14
  processor = CLIPProcessor.from_pretrained(model_name)
15
  print("Model and processor loaded successfully.")
16
  except Exception as e:
17
  print(f"Error loading the model or processor: {e}")
18
- raise RuntimeError(f"Failed to load model: {e}")
19
-
20
- # Step 2: Minimal Test Case to Verify Model and Processor
21
- try:
22
- print("Running a minimal test case with the model...")
23
-
24
- # Test Image URL
25
- url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
26
- image = Image.open(requests.get(url, stream=True).raw)
27
-
28
- # Define test categories
29
- test_categories = ["safe", "unsafe"]
30
-
31
- # Process the image
32
- test_inputs = processor(text=test_categories, images=image, return_tensors="pt", padding=True)
33
- print(f"Test inputs processed: {test_inputs}")
34
 
35
- # Perform inference
36
- test_outputs = model(**test_inputs)
37
- print(f"Test outputs: {test_outputs}")
38
-
39
- # Check probabilities
40
- test_logits = test_outputs.logits_per_image
41
- test_probs = test_logits.softmax(dim=1)
42
- print(f"Test probabilities: {test_probs}")
43
-
44
- except Exception as e:
45
- print(f"Error during the minimal test case: {e}")
46
- raise RuntimeError(f"Test case failed: {e}")
47
-
48
- # Step 3: Define the Inference Function
49
  def classify_image(image):
50
- """
51
- Classify an image as 'safe' or 'unsafe' and return probabilities.
52
-
53
- Args:
54
- image (PIL.Image.Image): Uploaded image.
55
-
56
- Returns:
57
- str: Predicted category.
58
- dict: Probabilities for "safe" and "unsafe".
59
- """
60
  try:
61
  print("Starting image classification...")
62
 
63
- # Check if the image is valid
64
  if image is None:
65
  raise ValueError("No image provided. Please upload a valid image.")
66
  if not hasattr(image, "convert"):
67
  raise ValueError("Uploaded file is not a valid image format.")
68
 
69
- # Define main categories
70
  categories = ["safe", "unsafe"]
71
  print(f"Categories: {categories}")
72
 
73
  # Process the image
74
- print("Processing the image with the processor...")
75
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
76
  print(f"Processed inputs: {inputs}")
77
 
78
  # Perform inference
79
- print("Running model inference...")
80
  outputs = model(**inputs)
81
  print(f"Model outputs: {outputs}")
82
 
@@ -85,34 +42,31 @@ def classify_image(image):
85
  probs = logits_per_image.softmax(dim=1)
86
  print(f"Probabilities: {probs}")
87
 
88
- # Extract probabilities for each category
89
  safe_prob = probs[0][0].item() * 100
90
  unsafe_prob = probs[0][1].item() * 100
 
91
 
92
  # Determine the predicted category
93
  predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
94
- print(f"Predicted category: {predicted_category}")
95
-
96
- # Return the predicted category and probabilities
97
  return predicted_category, {"safe": f"{safe_prob:.2f}%", "unsafe": f"{unsafe_prob:.2f}%"}
98
 
99
  except Exception as e:
100
  print(f"Error during classification: {e}")
101
  return f"Error: {str(e)}", {}
102
 
103
- # Step 4: Set Up Gradio Interface
104
  iface = gr.Interface(
105
  fn=classify_image,
106
  inputs=gr.Image(type="pil"),
107
  outputs=[
108
- gr.Textbox(label="Predicted Category"), # Display the predicted category prominently
109
- gr.Label(label="Probabilities"), # Display probabilities with a progress bar
110
  ],
111
  title="Content Safety Classification",
112
  description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
113
  )
114
 
115
- # Step 5: Launch Gradio Interface
116
  if __name__ == "__main__":
117
  print("Launching Gradio interface...")
118
  iface.launch()
@@ -137,5 +91,6 @@ if __name__ == "__main__":
137
 
138
 
139
 
 
140
 
141
 
 
1
  import gradio as gr
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
 
4
 
5
+ # Load the model and processor
6
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
 
7
  print("Initializing the application...")
8
 
9
  try:
 
10
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
11
  processor = CLIPProcessor.from_pretrained(model_name)
12
  print("Model and processor loaded successfully.")
13
  except Exception as e:
14
  print(f"Error loading the model or processor: {e}")
15
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Define the inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def classify_image(image):
 
 
 
 
 
 
 
 
 
 
19
  try:
20
  print("Starting image classification...")
21
 
22
+ # Validate image input
23
  if image is None:
24
  raise ValueError("No image provided. Please upload a valid image.")
25
  if not hasattr(image, "convert"):
26
  raise ValueError("Uploaded file is not a valid image format.")
27
 
28
+ # Define categories
29
  categories = ["safe", "unsafe"]
30
  print(f"Categories: {categories}")
31
 
32
  # Process the image
 
33
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
34
  print(f"Processed inputs: {inputs}")
35
 
36
  # Perform inference
 
37
  outputs = model(**inputs)
38
  print(f"Model outputs: {outputs}")
39
 
 
42
  probs = logits_per_image.softmax(dim=1)
43
  print(f"Probabilities: {probs}")
44
 
45
+ # Extract probabilities
46
  safe_prob = probs[0][0].item() * 100
47
  unsafe_prob = probs[0][1].item() * 100
48
+ print(f"Safe: {safe_prob:.2f}%, Unsafe: {unsafe_prob:.2f}%")
49
 
50
  # Determine the predicted category
51
  predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
 
 
 
52
  return predicted_category, {"safe": f"{safe_prob:.2f}%", "unsafe": f"{unsafe_prob:.2f}%"}
53
 
54
  except Exception as e:
55
  print(f"Error during classification: {e}")
56
  return f"Error: {str(e)}", {}
57
 
58
+ # Gradio interface
59
  iface = gr.Interface(
60
  fn=classify_image,
61
  inputs=gr.Image(type="pil"),
62
  outputs=[
63
+ gr.Textbox(label="Predicted Category"),
64
+ gr.Label(label="Probabilities"),
65
  ],
66
  title="Content Safety Classification",
67
  description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
68
  )
69
 
 
70
  if __name__ == "__main__":
71
  print("Launching Gradio interface...")
72
  iface.launch()
 
91
 
92
 
93
 
94
+
95
 
96