prithivMLmods commited on
Commit
413a6e6
·
verified ·
1 Parent(s): c4fcb59

Update deepfake_vs_real.py

Browse files
Files changed (1) hide show
  1. deepfake_vs_real.py +10 -15
deepfake_vs_real.py CHANGED
@@ -1,15 +1,13 @@
1
  import gradio as gr
2
  import spaces
3
- from transformers import AutoImageProcessor
4
- from transformers import SiglipForImageClassification
5
- from transformers.image_utils import load_image
6
  from PIL import Image
7
  import torch
8
 
9
- # Load model and processor
10
- model_name = "prithivMLmods/Deepfake-vs-Real-8000"
11
- model = SiglipForImageClassification.from_pretrained(model_name)
12
- processor = AutoImageProcessor.from_pretrained(model_name)
13
 
14
  @spaces.GPU
15
  def deepfake_classification(image):
@@ -20,20 +18,17 @@ def deepfake_classification(image):
20
  with torch.no_grad():
21
  outputs = model(**inputs)
22
  logits = outputs.logits
23
- probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
24
 
25
- labels = {
26
- "0": "Deepfake", "1": "Real one"
27
- }
28
- predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
29
-
30
- return predictions
31
 
32
  # Create Gradio interface
33
  iface = gr.Interface(
34
  fn=deepfake_classification,
35
  inputs=gr.Image(type="numpy"),
36
- outputs=gr.Label(label="Prediction Scores"),
37
  title="Deepfake vs. Real Image Classification",
38
  description="Upload an image to determine if it's a Deepfake or a Real one."
39
  )
 
1
  import gradio as gr
2
  import spaces
3
+ from transformers import ViTForImageClassification, ViTImageProcessor
 
 
4
  from PIL import Image
5
  import torch
6
 
7
+ # Load the model and processor
8
+ model_name = "prithivMLmods/Deep-Fake-Detector-v2-Model"
9
+ model = ViTForImageClassification.from_pretrained(model_name)
10
+ processor = ViTImageProcessor.from_pretrained(model_name)
11
 
12
  @spaces.GPU
13
  def deepfake_classification(image):
 
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
  logits = outputs.logits
21
+ predicted_class = torch.argmax(logits, dim=1).item()
22
 
23
+ # Get label mapping
24
+ label = model.config.id2label[predicted_class] if hasattr(model.config, "id2label") else str(predicted_class)
25
+ return {label: 1.0} # Gradio Label output expects a dictionary
 
 
 
26
 
27
  # Create Gradio interface
28
  iface = gr.Interface(
29
  fn=deepfake_classification,
30
  inputs=gr.Image(type="numpy"),
31
+ outputs=gr.Label(label="Prediction"),
32
  title="Deepfake vs. Real Image Classification",
33
  description="Upload an image to determine if it's a Deepfake or a Real one."
34
  )