Dileep7729 commited on
Commit
d237a07
·
verified ·
1 Parent(s): 610954a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -21
app.py CHANGED
@@ -12,7 +12,7 @@ print("Model loaded successfully.")
12
  # Step 2: Define the Inference Function
13
  def classify_image(image):
14
  """
15
- Classify an image as 'safe' or 'unsafe' with probabilities and display as a progress bar.
16
 
17
  Args:
18
  image (PIL.Image.Image): The input image.
@@ -20,32 +20,52 @@ def classify_image(image):
20
  Returns:
21
  dict: A dictionary containing probabilities for 'safe' and 'unsafe'.
22
  """
23
- # Define the main categories
24
- main_categories = ["safe", "unsafe"]
25
-
26
- # Process the image with the main categories
27
- inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
28
- outputs = model(**inputs)
29
- logits_per_image = outputs.logits_per_image # Image-text similarity scores
30
- probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
31
-
32
- # Extract the probabilities
33
- safe_probability = probs[0][0].item() * 100 # Safe percentage
34
- unsafe_probability = probs[0][1].item() * 100 # Unsafe percentage
35
-
36
- # Return probabilities as a dictionary for display in Gradio's Label component
37
- return {
38
- "safe": f"{safe_probability:.2f}%",
39
- "unsafe": f"{unsafe_probability:.2f}%"
40
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Step 3: Set Up Gradio Interface
43
  iface = gr.Interface(
44
  fn=classify_image,
45
  inputs=gr.Image(type="pil"),
46
- outputs=gr.Label(label="Output"),
47
  title="Content Safety Classification",
48
- description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
 
 
 
49
  )
50
 
51
  # Step 4: Launch Gradio Interface
@@ -68,3 +88,4 @@ if __name__ == "__main__":
68
 
69
 
70
 
 
 
12
  # Step 2: Define the Inference Function
13
  def classify_image(image):
14
  """
15
+ Classify an image as 'safe' or 'unsafe' with the corresponding percentage.
16
 
17
  Args:
18
  image (PIL.Image.Image): The input image.
 
20
  Returns:
21
  dict: A dictionary containing probabilities for 'safe' and 'unsafe'.
22
  """
23
+ try:
24
+ # Debug: Check if the image is loaded
25
+ if image is None:
26
+ raise ValueError("No image provided. Please upload an image.")
27
+
28
+ # Define the main categories
29
+ main_categories = ["safe", "unsafe"]
30
+
31
+ # Process the image with the model processor
32
+ print("Processing the image...")
33
+ inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
34
+ print(f"Inputs processed: {inputs}")
35
+
36
+ # Perform inference using the model
37
+ outputs = model(**inputs)
38
+ print(f"Model outputs: {outputs}")
39
+
40
+ # Extract probabilities for each category
41
+ logits_per_image = outputs.logits_per_image # Image-text similarity scores
42
+ probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
43
+
44
+ # Safe and unsafe probabilities
45
+ safe_probability = probs[0][0].item() * 100 # Convert to percentage
46
+ unsafe_probability = probs[0][1].item() * 100 # Convert to percentage
47
+
48
+ print(f"Safe: {safe_probability:.2f}%, Unsafe: {unsafe_probability:.2f}%")
49
+
50
+ # Return the results as a dictionary for display in Gradio
51
+ return {
52
+ "safe": f"{safe_probability:.2f}%",
53
+ "unsafe": f"{unsafe_probability:.2f}%"
54
+ }
55
+ except Exception as e:
56
+ print(f"Error during inference: {str(e)}")
57
+ return {"Error": str(e)}
58
 
59
  # Step 3: Set Up Gradio Interface
60
  iface = gr.Interface(
61
  fn=classify_image,
62
  inputs=gr.Image(type="pil"),
63
+ outputs=gr.Label(label="Output"), # Use Gradio's Label component for progress bar display
64
  title="Content Safety Classification",
65
+ description=(
66
+ "Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities. "
67
+ "The model will analyze the image and provide probabilities for each category."
68
+ ),
69
  )
70
 
71
  # Step 4: Launch Gradio Interface
 
88
 
89
 
90
 
91
+