Dileep7729 commited on
Commit
a41b014
·
verified ·
1 Parent(s): 4df31f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -28
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import CLIPModel, CLIPProcessor
3
 
4
  # Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
@@ -10,45 +11,59 @@ processor = CLIPProcessor.from_pretrained(model_name)
10
  print("Model loaded successfully.")
11
 
12
  # Step 2: Define the Inference Function
13
- def classify_image(image, class_names):
14
  """
15
- Classify an image as 'safe' or 'unsafe' using the fine-tuned CLIP model.
16
 
17
  Args:
18
  image (PIL.Image.Image): The input image.
19
- class_names (str): Comma-separated class names (e.g., "safe, unsafe").
20
 
21
  Returns:
22
- dict: A dictionary containing class names and their probabilities.
23
  """
24
- # Split class names from comma-separated input
25
- labels = [label.strip() for label in class_names.split(",") if label.strip()]
26
- if not labels:
27
- return {"Error": "Please enter at least one valid class name."}
28
-
29
- # Process the image and labels
30
- inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
31
- outputs = model(**inputs)
32
- logits_per_image = outputs.logits_per_image # Get image-text similarity scores
33
- probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
34
-
35
- # Extract labels with their corresponding probabilities
36
- result = {label: probs[0][i].item() for i, label in enumerate(labels)}
37
- return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Step 3: Set Up Gradio Interface
40
  iface = gr.Interface(
41
  fn=classify_image,
42
- inputs=[
43
- gr.Image(type="pil"),
44
- gr.Textbox(
45
- label="Possible class names (comma-separated)",
46
- placeholder="e.g., safe, unsafe"
47
- )
48
- ],
49
- outputs=gr.Label(num_top_classes=2),
50
- title="Content Safety Classification",
51
- description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model."
52
  )
53
 
54
  # Step 4: Launch Gradio Interface
@@ -68,3 +83,4 @@ if __name__ == "__main__":
68
 
69
 
70
 
 
 
1
  import gradio as gr
2
+ import gradio as gr
3
  from transformers import CLIPModel, CLIPProcessor
4
 
5
  # Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
 
11
  print("Model loaded successfully.")
12
 
13
  # Step 2: Define the Inference Function
14
+ def classify_image(image):
15
  """
16
+ Classify an image as 'safe' or 'unsafe' with probabilities and subcategories.
17
 
18
  Args:
19
  image (PIL.Image.Image): The input image.
 
20
 
21
  Returns:
22
+ dict: A dictionary containing main categories (safe/unsafe) and their probabilities.
23
  """
24
+ # Define the predefined categories
25
+ main_categories = ["safe", "unsafe"]
26
+ safe_subcategories = ["retail product", "other safe content"]
27
+ unsafe_subcategories = ["harmful", "violent", "sexual", "self harm"]
28
+
29
+ # Process the image with the main categories
30
+ main_inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
31
+ main_outputs = model(**main_inputs)
32
+ logits_per_image = main_outputs.logits_per_image # Image-text similarity scores
33
+ main_probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
34
+
35
+ # Determine the main category
36
+ main_result = {main_categories[i]: main_probs[0][i].item() for i in range(len(main_categories))}
37
+ main_category = max(main_result, key=main_result.get) # Either "safe" or "unsafe"
38
+
39
+ # Process the image with subcategories based on the main category
40
+ subcategories = safe_subcategories if main_category == "safe" else unsafe_subcategories
41
+ sub_inputs = processor(text=subcategories, images=image, return_tensors="pt", padding=True)
42
+ sub_outputs = model(**sub_inputs)
43
+ sub_logits = sub_outputs.logits_per_image
44
+ sub_probs = sub_logits.softmax(dim=1) # Convert logits to probabilities
45
+
46
+ # Create a structured result
47
+ result = {
48
+ "Main Category": main_category,
49
+ "Main Probabilities": main_result,
50
+ "Subcategory Probabilities": {
51
+ subcategories[i]: sub_probs[0][i].item() for i in range(len(subcategories))
52
+ }
53
+ }
54
+ return result
55
 
56
  # Step 3: Set Up Gradio Interface
57
  iface = gr.Interface(
58
  fn=classify_image,
59
+ inputs=gr.Image(type="pil"),
60
+ outputs="json",
61
+ title="Enhanced Content Safety Classification",
62
+ description=(
63
+ "Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model. "
64
+ "For 'safe', identify subcategories such as 'retail product'. "
65
+ "For 'unsafe', identify subcategories such as 'harmful', 'violent', 'sexual', or 'self harm'."
66
+ ),
 
 
67
  )
68
 
69
  # Step 4: Launch Gradio Interface
 
83
 
84
 
85
 
86
+