LPX55 commited on
Commit
a9d7990
·
verified ·
1 Parent(s): 1e046c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -1,27 +1,34 @@
1
  import gradio as gr
2
- # from transformers import AutoBackbone, AutoModelForImageClassification, AutoImageProcessor, Swinv2ForImageClassification
3
  from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
4
  from torchvision import transforms
5
 
6
- # model = AutoModelForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
7
- # image_processor = AutoImageProcessor.from_pretrained("haywoodsloan/ai-image-detector-deploy")
8
  image_processor = AutoImageProcessor.from_pretrained("haywoodsloan/ai-image-detector-deploy")
9
- # image_processor = Swinv2ForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
10
  model = Swinv2ForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
11
-
12
  clf = pipeline(model=model, task="image-classification", image_processor=image_processor)
13
 
 
14
  class_names = ['artificial', 'real']
15
 
16
  def predict_image(img):
17
- img = transforms.ToPILImage()(img)
18
- img = transforms.Resize((256,256))(img)
19
- prediction=clf.predict(img)
20
- print(prediction)
21
-
22
- return {class_names[i]: float(prediction[i]["score"]) for i in range(2)}
 
 
 
 
 
 
 
 
 
 
23
 
 
24
  image = gr.Image(label="Image to Analyze", sources=['upload'])
25
  label = gr.Label(num_top_classes=2)
26
-
27
  gr.Interface(fn=predict_image, inputs=image, outputs=label, title="AI Generated Classification").launch()
 
1
  import gradio as gr
 
2
  from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
3
  from torchvision import transforms
4
 
5
+ # Load the model and processor
 
6
  image_processor = AutoImageProcessor.from_pretrained("haywoodsloan/ai-image-detector-deploy")
 
7
  model = Swinv2ForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
 
8
  clf = pipeline(model=model, task="image-classification", image_processor=image_processor)
9
 
10
+ # Define class names
11
  class_names = ['artificial', 'real']
12
 
13
  def predict_image(img):
14
+ # Convert the image to a PIL Image and resize it
15
+ img = transforms.ToPILImage()(img)
16
+ img = transforms.Resize((256, 256))(img)
17
+
18
+ # Get the prediction
19
+ prediction = clf(img)
20
+
21
+ # Process the prediction to match the class names
22
+ result = {pred['label']: pred['score'] for pred in prediction}
23
+
24
+ # Ensure the result dictionary contains both class names
25
+ for class_name in class_names:
26
+ if class_name not in result:
27
+ result[class_name] = 0.0
28
+
29
+ return result
30
 
31
+ # Define the Gradio interface
32
  image = gr.Image(label="Image to Analyze", sources=['upload'])
33
  label = gr.Label(num_top_classes=2)
 
34
  gr.Interface(fn=predict_image, inputs=image, outputs=label, title="AI Generated Classification").launch()