Dileep7729 commited on
Commit
dcff825
·
verified ·
1 Parent(s): 988ceee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -2,85 +2,84 @@ import gradio as gr
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
4
 
5
- # Load Model and Processor
6
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
7
- print("Initializing the model and processor...")
 
8
 
9
  try:
 
10
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
11
  processor = CLIPProcessor.from_pretrained(model_name)
12
  print("Model and processor loaded successfully.")
13
  except Exception as e:
14
- print(f"Error loading model or processor: {e}")
15
- raise RuntimeError(f"Failed to load the model: {e}")
16
-
17
 
18
- # Inference Function
19
  def classify_image(image):
20
  """
21
- Classifies an image as 'safe' or 'unsafe' using the CLIP model.
22
 
23
  Args:
24
  image (PIL.Image.Image): Uploaded image.
25
-
26
  Returns:
27
- Tuple: Predicted category and probabilities for "safe" and "unsafe".
28
  """
29
  try:
30
  print("Starting image classification...")
31
 
32
- # Validate image input
33
  if image is None:
34
  raise ValueError("No image provided. Please upload a valid image.")
 
 
35
  if not hasattr(image, "convert"):
36
- raise ValueError("Uploaded file is not a valid image format (JPEG, PNG, etc.).")
37
 
38
- # Define classification categories
39
  categories = ["safe", "unsafe"]
40
- print(f"Using categories: {categories}")
41
 
42
- # Process image using CLIPProcessor
 
43
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
44
- print("Image processed successfully.")
45
 
46
- # Perform inference
 
47
  outputs = model(**inputs)
48
- print("Inference completed successfully.")
49
 
50
- # Calculate probabilities
51
- logits_per_image = outputs.logits_per_image
52
- probs = logits_per_image.softmax(dim=1)
53
- print(f"Probabilities: {probs}")
54
 
55
  # Extract probabilities for each category
56
- safe_prob = probs[0][0].item() * 100
57
- unsafe_prob = probs[0][1].item() * 100
58
- print(f"Safe: {safe_prob:.2f}%, Unsafe: {unsafe_prob:.2f}%")
59
 
60
- # Determine the predicted category
61
- predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
62
- print(f"Predicted Category: {predicted_category}")
63
-
64
- # Return category and probabilities
65
- return predicted_category, {"safe": f"{safe_prob:.2f}%", "unsafe": f"{unsafe_prob:.2f}%"}
66
 
67
  except Exception as e:
68
  print(f"Error during classification: {e}")
69
- return "Error", {"safe": "N/A", "unsafe": "N/A"}
70
-
71
 
72
- # Gradio Interface
73
  iface = gr.Interface(
74
  fn=classify_image,
75
- inputs=gr.Image(type="pil"), # Accept image input
76
- outputs=[
77
- gr.Textbox(label="Predicted Category"), # Predicted category
78
- gr.Label(label="Probabilities"), # Probabilities as progress bars
79
- ],
80
  title="Content Safety Classification",
81
- description="Upload an image to classify it as 'safe' or 'unsafe' with probabilities.",
82
  )
83
 
 
84
  if __name__ == "__main__":
85
  print("Launching the Gradio interface...")
86
  iface.launch()
@@ -107,6 +106,8 @@ if __name__ == "__main__":
107
 
108
 
109
 
 
 
110
 
111
 
112
 
 
2
  from transformers import CLIPModel, CLIPProcessor
3
  from PIL import Image
4
 
5
+ # Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
6
  model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
7
+
8
+ print("Initializing the application...")
9
 
10
  try:
11
+ print("Loading the model from Hugging Face Model Hub...")
12
  model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
13
  processor = CLIPProcessor.from_pretrained(model_name)
14
  print("Model and processor loaded successfully.")
15
  except Exception as e:
16
+ print(f"Error loading the model or processor: {e}")
17
+ raise RuntimeError(f"Failed to load model: {e}")
 
18
 
19
+ # Step 2: Define the Inference Function
20
  def classify_image(image):
21
  """
22
+ Classify an image as 'safe' or 'unsafe' and return probabilities.
23
 
24
  Args:
25
  image (PIL.Image.Image): Uploaded image.
26
+
27
  Returns:
28
+ dict: Classification results or an error message.
29
  """
30
  try:
31
  print("Starting image classification...")
32
 
33
+ # Validate input
34
  if image is None:
35
  raise ValueError("No image provided. Please upload a valid image.")
36
+
37
+ # Validate image format
38
  if not hasattr(image, "convert"):
39
+ raise ValueError("Invalid image format. Please upload a valid image (JPEG, PNG, etc.).")
40
 
41
+ # Define categories
42
  categories = ["safe", "unsafe"]
 
43
 
44
+ # Process the image with the processor
45
+ print("Processing the image...")
46
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
47
+ print(f"Processed inputs: {inputs}")
48
 
49
+ # Run inference with the model
50
+ print("Running model inference...")
51
  outputs = model(**inputs)
52
+ print(f"Model outputs: {outputs}")
53
 
54
+ # Extract logits and probabilities
55
+ logits_per_image = outputs.logits_per_image # Image-text similarity scores
56
+ probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
57
+ print(f"Calculated probabilities: {probs}")
58
 
59
  # Extract probabilities for each category
60
+ safe_prob = probs[0][0].item() * 100 # Safe percentage
61
+ unsafe_prob = probs[0][1].item() * 100 # Unsafe percentage
 
62
 
63
+ # Return results
64
+ return {
65
+ "safe": f"{safe_prob:.2f}%",
66
+ "unsafe": f"{unsafe_prob:.2f}%"
67
+ }
 
68
 
69
  except Exception as e:
70
  print(f"Error during classification: {e}")
71
+ return {"Error": str(e)}
 
72
 
73
+ # Step 3: Set Up Gradio Interface
74
  iface = gr.Interface(
75
  fn=classify_image,
76
+ inputs=gr.Image(type="pil"),
77
+ outputs=gr.Label(label="Output"), # Display probabilities as a percentage scale
 
 
 
78
  title="Content Safety Classification",
79
+ description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
80
  )
81
 
82
+ # Step 4: Launch Gradio Interface
83
  if __name__ == "__main__":
84
  print("Launching the Gradio interface...")
85
  iface.launch()
 
106
 
107
 
108
 
109
+
110
+
111
 
112
 
113