Dileep7729 committed
Commit 4df31f3 · verified · 1 Parent(s): bbfef86

Update app.py

Files changed (1)
app.py +27 -23
app.py CHANGED
@@ -1,26 +1,26 @@
-import os
-import torch
-from torchvision import transforms
-from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
 import gradio as gr
+from transformers import CLIPModel, CLIPProcessor
 
-# Step 1: Ensure Fine-Tuned Model is Available
-fine_tuned_model_path = "fine-tuned-model"
-
-if not os.path.exists(fine_tuned_model_path):
-    raise FileNotFoundError(
-        f"The fine-tuned model is missing. Ensure that the fine-tuned model files are available in the '{fine_tuned_model_path}' directory."
-    )
+# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
+model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
 
-# Step 2: Load Fine-Tuned Model
-print("Loading fine-tuned model...")
-model = CLIPModel.from_pretrained(fine_tuned_model_path)
-processor = CLIPProcessor.from_pretrained(fine_tuned_model_path)
-print("Fine-tuned model loaded successfully.")
+print("Loading the fine-tuned model from Hugging Face Model Hub...")
+model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
+processor = CLIPProcessor.from_pretrained(model_name)
+print("Model loaded successfully.")
 
-# Step 3: Define Gradio Inference Function
+# Step 2: Define the Inference Function
 def classify_image(image, class_names):
+    """
+    Classify an image as 'safe' or 'unsafe' using the fine-tuned CLIP model.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        class_names (str): Comma-separated class names (e.g., "safe, unsafe").
+
+    Returns:
+        dict: A dictionary containing class names and their probabilities.
+    """
     # Split class names from comma-separated input
     labels = [label.strip() for label in class_names.split(",") if label.strip()]
     if not labels:
@@ -29,26 +29,29 @@ def classify_image(image, class_names):
     # Process the image and labels
     inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
     outputs = model(**inputs)
-    logits_per_image = outputs.logits_per_image
+    logits_per_image = outputs.logits_per_image  # Get image-text similarity scores
     probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
 
     # Extract labels with their corresponding probabilities
     result = {label: probs[0][i].item() for i, label in enumerate(labels)}
     return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
 
-# Step 4: Set Up Gradio Interface
+# Step 3: Set Up Gradio Interface
 iface = gr.Interface(
     fn=classify_image,
     inputs=[
         gr.Image(type="pil"),
-        gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe")
+        gr.Textbox(
+            label="Possible class names (comma-separated)",
+            placeholder="e.g., safe, unsafe"
+        )
     ],
     outputs=gr.Label(num_top_classes=2),
     title="Content Safety Classification",
-    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
+    description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model."
 )
 
-# Step 5: Launch Gradio Interface
+# Step 4: Launch Gradio Interface
 if __name__ == "__main__":
     iface.launch()
 
@@ -64,3 +67,4 @@ if __name__ == "__main__":
 
 
 
+
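For quick verification, the new loading-and-inference path can be exercised without the Gradio UI. The sketch below mirrors the calls app.py now makes; the model id is taken from this commit, while "example.jpg" is a placeholder path and the label list simply follows the "safe, unsafe" example from the interface.

# Minimal sketch of the inference path from app.py, run outside Gradio.
# Assumes transformers, torch, and pillow are installed; "example.jpg" is a placeholder.
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
processor = CLIPProcessor.from_pretrained(model_name)

image = Image.open("example.jpg")  # placeholder input image
labels = ["safe", "unsafe"]        # same classes the interface suggests

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)  # similarity scores -> probabilities

# Print classes from most to least probable, like the sorted dict in classify_image
for label, p in sorted(zip(labels, probs[0].tolist()), key=lambda x: x[1], reverse=True):
    print(f"{label}: {p:.3f}")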