abhisheksan commited on
Commit
b4b3464
·
1 Parent(s): 29ee189

Refactor forgery_routes.py and routes.py

Browse files

- Refactor forgery_routes.py to import Response from fastapi
- Add detect_speech function to forgery_video_utils.py
- Update supported image and video formats in routes.py
- Remove verify_image_format function from image_utils.py

app/api/forgery_routes.py CHANGED
@@ -5,7 +5,8 @@ from app.services.audio_deepfake_service import AudioDeepfakeService
5
  from app.services.gan_detection_service import GANDetectionService
6
  from app.utils.file_utils import download_file, remove_temp_file, get_file_content
7
  from app.utils.forgery_image_utils import detect_face
8
- from app.utils.forgery_video_utils import extract_audio, extract_frames, compress_and_process_video
 
9
  import os
10
  import logging
11
  import traceback
@@ -45,7 +46,7 @@ async def detect_forgery(request: DetectForgeryRequest):
45
  firebase_filename = await download_file(file_url)
46
  logging.info(f"File downloaded and saved as: {firebase_filename}")
47
 
48
- if file_extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp']:
49
  logging.info(f"Processing image file: {firebase_filename}")
50
  return await process_image(firebase_filename)
51
  elif file_extension in ['mp4', 'avi', 'mov', 'flv', 'wmv']:
@@ -92,23 +93,29 @@ async def process_video(firebase_filename: str):
92
  logging.info(f"Video compressed: {compressed_video_filename}")
93
 
94
  audio_filename = await extract_audio(compressed_video_filename)
95
- logging.info(f"Audio extracted: {audio_filename}")
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  frames = await extract_frames(compressed_video_filename)
98
  logging.info(f"Frames extracted: {len(frames)} frames")
99
 
100
- results = {
101
- "audio_deepfake": None,
102
  "image_manipulation": [],
103
  "face_manipulation": [],
104
  "gan_detection": []
105
- }
106
-
107
- if audio_filename:
108
- results["audio_deepfake"] = audio_deepfake_service.detect_deepfake(audio_filename)
109
- logging.info(f"Audio deepfake detection result: {results['audio_deepfake']}")
110
- await remove_temp_file(audio_filename)
111
- logging.info(f"Temporary audio file removed: {audio_filename}")
112
 
113
  face_frames = []
114
  for i, frame in enumerate(frames):
 
5
  from app.services.gan_detection_service import GANDetectionService
6
  from app.utils.file_utils import download_file, remove_temp_file, get_file_content
7
  from app.utils.forgery_image_utils import detect_face
8
+ from app.utils.forgery_video_utils import extract_audio, extract_frames, compress_and_process_video, detect_speech # Adjust the import path if necessary
9
+
10
  import os
11
  import logging
12
  import traceback
 
46
  firebase_filename = await download_file(file_url)
47
  logging.info(f"File downloaded and saved as: {firebase_filename}")
48
 
49
+ if file_extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'gif', 'tiff', 'webp']:
50
  logging.info(f"Processing image file: {firebase_filename}")
51
  return await process_image(firebase_filename)
52
  elif file_extension in ['mp4', 'avi', 'mov', 'flv', 'wmv']:
 
93
  logging.info(f"Video compressed: {compressed_video_filename}")
94
 
95
  audio_filename = await extract_audio(compressed_video_filename)
96
+ if audio_filename:
97
+ logging.info(f"Audio extracted successfully: {audio_filename}")
98
+ audio_content = get_file_content(audio_filename)
99
+ if detect_speech(audio_content):
100
+ logging.info("Speech detected in the audio")
101
+ results = {"audio_deepfake": audio_deepfake_service.detect_deepfake(audio_filename)}
102
+ else:
103
+ logging.info("No speech detected in the audio")
104
+ results = {"audio_deepfake": {"prediction": "No speech detected", "confidence": 1.0, "raw_prediction": 1.0}}
105
+ await remove_temp_file(audio_filename)
106
+ logging.info(f"Temporary audio file removed: {audio_filename}")
107
+ else:
108
+ logging.warning("No audio detected or extracted from the video")
109
+ results = {"audio_deepfake": {"prediction": "No audio", "confidence": 1.0, "raw_prediction": 1.0}}
110
+
111
  frames = await extract_frames(compressed_video_filename)
112
  logging.info(f"Frames extracted: {len(frames)} frames")
113
 
114
+ results.update({
 
115
  "image_manipulation": [],
116
  "face_manipulation": [],
117
  "gan_detection": []
118
+ })
 
 
 
 
 
 
119
 
120
  face_frames = []
121
  for i, frame in enumerate(frames):
app/api/routes.py CHANGED
@@ -4,6 +4,7 @@ from app.services import video_service, image_service, antispoof_service
4
  from app.services.antispoof_service import antispoof_service
5
  from app.services.image_service import compare_images
6
  import logging
 
7
 
8
  router = APIRouter()
9
 
@@ -13,6 +14,18 @@ class ContentRequest(BaseModel):
13
  class CompareRequest(BaseModel):
14
  url1: str
15
  url2: str
 
 
 
 
 
 
 
 
 
 
 
 
16
  @router.get("/health")
17
  @router.head("/health")
18
  async def health_check():
@@ -20,8 +33,11 @@ async def health_check():
20
  Health check endpoint that responds to both GET and HEAD requests.
21
  """
22
  return Response(content="OK", media_type="text/plain")
 
23
  @router.post("/fingerprint")
24
  async def create_fingerprint(request: ContentRequest):
 
 
25
  try:
26
  result = await video_service.fingerprint_video(request.url)
27
  return {"message": "Fingerprint processing completed", "result": result}
@@ -31,6 +47,8 @@ async def create_fingerprint(request: ContentRequest):
31
 
32
  @router.post("/verify_video_only")
33
  async def verify_video_only(request: ContentRequest):
 
 
34
  try:
35
  result = await video_service.fingerprint_video(request.url)
36
  return {"message": "Video verification completed", "result": result}
@@ -40,6 +58,8 @@ async def verify_video_only(request: ContentRequest):
40
 
41
  @router.post("/verify_liveness")
42
  async def verify_liveness(request: ContentRequest):
 
 
43
  try:
44
  result = await antispoof_service.verify_liveness(request.url)
45
  return {"message": "Liveness verification completed", "result": result}
@@ -49,6 +69,8 @@ async def verify_liveness(request: ContentRequest):
49
 
50
  @router.post("/compare_videos")
51
  async def compare_videos_route(request: CompareRequest):
 
 
52
  try:
53
  result = await video_service.compare_videos(request.url1, request.url2)
54
  return {"message": "Video comparison completed", "result": result}
@@ -58,6 +80,8 @@ async def compare_videos_route(request: CompareRequest):
58
 
59
  @router.post("/verify_image")
60
  async def verify_image_route(request: ContentRequest):
 
 
61
  try:
62
  result = await image_service.verify_image(request.url)
63
  return {"message": "Image verification completed", "result": result}
@@ -67,6 +91,8 @@ async def verify_image_route(request: ContentRequest):
67
 
68
  @router.post("/compare_images")
69
  async def compare_images_route(request: CompareRequest):
 
 
70
  try:
71
  result = await compare_images(request.url1, request.url2)
72
  return {"message": "Image comparison completed", "result": result}
 
4
  from app.services.antispoof_service import antispoof_service
5
  from app.services.image_service import compare_images
6
  import logging
7
+ import os
8
 
9
  router = APIRouter()
10
 
 
14
  class CompareRequest(BaseModel):
15
  url1: str
16
  url2: str
17
+
18
+ SUPPORTED_VIDEO_FORMATS = ['mp4', 'avi', 'mov', 'flv', 'wmv']
19
+ SUPPORTED_IMAGE_FORMATS = ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'webp']
20
+
21
+ def is_supported_video_format(url: str) -> bool:
22
+ file_extension = os.path.splitext(url)[1][1:].lower()
23
+ return file_extension in SUPPORTED_VIDEO_FORMATS
24
+
25
+ def is_supported_image_format(url: str) -> bool:
26
+ file_extension = os.path.splitext(url)[1][1:].lower()
27
+ return file_extension in SUPPORTED_IMAGE_FORMATS
28
+
29
  @router.get("/health")
30
  @router.head("/health")
31
  async def health_check():
 
33
  Health check endpoint that responds to both GET and HEAD requests.
34
  """
35
  return Response(content="OK", media_type="text/plain")
36
+
37
  @router.post("/fingerprint")
38
  async def create_fingerprint(request: ContentRequest):
39
+ if not is_supported_video_format(request.url):
40
+ raise HTTPException(status_code=400, detail="Video format not supported")
41
  try:
42
  result = await video_service.fingerprint_video(request.url)
43
  return {"message": "Fingerprint processing completed", "result": result}
 
47
 
48
  @router.post("/verify_video_only")
49
  async def verify_video_only(request: ContentRequest):
50
+ if not is_supported_video_format(request.url):
51
+ raise HTTPException(status_code=400, detail="Video format not supported")
52
  try:
53
  result = await video_service.fingerprint_video(request.url)
54
  return {"message": "Video verification completed", "result": result}
 
58
 
59
  @router.post("/verify_liveness")
60
  async def verify_liveness(request: ContentRequest):
61
+ if not is_supported_image_format(request.url):
62
+ raise HTTPException(status_code=400, detail="Image format not supported")
63
  try:
64
  result = await antispoof_service.verify_liveness(request.url)
65
  return {"message": "Liveness verification completed", "result": result}
 
69
 
70
  @router.post("/compare_videos")
71
  async def compare_videos_route(request: CompareRequest):
72
+ if not is_supported_video_format(request.url1) or not is_supported_video_format(request.url2):
73
+ raise HTTPException(status_code=400, detail="Video format not supported")
74
  try:
75
  result = await video_service.compare_videos(request.url1, request.url2)
76
  return {"message": "Video comparison completed", "result": result}
 
80
 
81
  @router.post("/verify_image")
82
  async def verify_image_route(request: ContentRequest):
83
+ if not is_supported_image_format(request.url):
84
+ raise HTTPException(status_code=400, detail="Image format not supported")
85
  try:
86
  result = await image_service.verify_image(request.url)
87
  return {"message": "Image verification completed", "result": result}
 
91
 
92
  @router.post("/compare_images")
93
  async def compare_images_route(request: CompareRequest):
94
+ if not is_supported_image_format(request.url1) or not is_supported_image_format(request.url2):
95
+ raise HTTPException(status_code=400, detail="Image format not supported")
96
  try:
97
  result = await compare_images(request.url1, request.url2)
98
  return {"message": "Image comparison completed", "result": result}
app/services/image_manipulation_service.py CHANGED
@@ -16,6 +16,9 @@ class ImageManipulationService:
16
  self.preprocessing_params = json.load(f)
17
 
18
  def convert_to_ela_image(self, image, quality):
 
 
 
19
  temp_buffer = io.BytesIO()
20
  image.save(temp_buffer, 'JPEG', quality=quality)
21
  temp_buffer.seek(0)
@@ -47,11 +50,13 @@ class ImageManipulationService:
47
  prediction = self.model.predict(prepared_image)
48
  predicted_class = int(np.argmax(prediction, axis=1)[0])
49
  confidence = float(np.max(prediction) * 100)
 
 
50
 
51
  result = {
52
  "class": self.class_names[predicted_class],
53
  "confidence": f"{confidence:.2f}%",
54
- "is_manipulated": bool(predicted_class == 0)
55
  }
56
 
57
  return result
 
16
  self.preprocessing_params = json.load(f)
17
 
18
  def convert_to_ela_image(self, image, quality):
19
+ if image.mode != 'RGB':
20
+ image = image.convert('RGB')
21
+
22
  temp_buffer = io.BytesIO()
23
  image.save(temp_buffer, 'JPEG', quality=quality)
24
  temp_buffer.seek(0)
 
50
  prediction = self.model.predict(prepared_image)
51
  predicted_class = int(np.argmax(prediction, axis=1)[0])
52
  confidence = float(np.max(prediction) * 100)
53
+
54
+ check_manipulated = bool(predicted_class == 0 and confidence > 90)
55
 
56
  result = {
57
  "class": self.class_names[predicted_class],
58
  "confidence": f"{confidence:.2f}%",
59
+ "is_manipulated": check_manipulated
60
  }
61
 
62
  return result
app/services/image_service.py CHANGED
@@ -1,4 +1,4 @@
1
- from app.utils.image_utils import verify_image_format, process_image, compare_images as compare_images_util
2
  from fastapi import HTTPException
3
  import logging
4
  from app.utils.file_utils import download_file, remove_temp_file
@@ -7,7 +7,6 @@ async def verify_image(image_url: str):
7
  firebase_filename = None
8
  try:
9
  firebase_filename = await download_file(image_url)
10
- verify_image_format(firebase_filename)
11
 
12
  image_hash = process_image(firebase_filename)
13
  return {"image_hash": image_hash}
@@ -26,10 +25,6 @@ async def compare_images(image_url1: str, image_url2: str):
26
  firebase_filename1 = await download_file(image_url1)
27
  firebase_filename2 = await download_file(image_url2)
28
 
29
- # Verify the image format for both images
30
- verify_image_format(firebase_filename1)
31
- verify_image_format(firebase_filename2)
32
-
33
  # Compare the images using the utility function
34
  comparison_result = compare_images_util(firebase_filename1, firebase_filename2)
35
 
 
1
+ from app.utils.image_utils import process_image, compare_images as compare_images_util
2
  from fastapi import HTTPException
3
  import logging
4
  from app.utils.file_utils import download_file, remove_temp_file
 
7
  firebase_filename = None
8
  try:
9
  firebase_filename = await download_file(image_url)
 
10
 
11
  image_hash = process_image(firebase_filename)
12
  return {"image_hash": image_hash}
 
25
  firebase_filename1 = await download_file(image_url1)
26
  firebase_filename2 = await download_file(image_url2)
27
 
 
 
 
 
28
  # Compare the images using the utility function
29
  comparison_result = compare_images_util(firebase_filename1, firebase_filename2)
30
 
app/utils/image_utils.py CHANGED
@@ -7,15 +7,6 @@ import imghdr
7
  from fastapi import HTTPException
8
  from app.utils.file_utils import get_file_content
9
 
10
- SUPPORTED_IMAGE_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp']
11
-
12
- def verify_image_format(filename: str):
13
- content = get_file_content(filename)
14
- file_ext = '.' + (imghdr.what(BytesIO(content)) or '')
15
-
16
- if file_ext not in SUPPORTED_IMAGE_FORMATS:
17
- raise HTTPException(status_code=400, detail=f"Unsupported image format. Supported formats are: {', '.join(SUPPORTED_IMAGE_FORMATS)}")
18
-
19
  def preprocess_image(image: Union[str, np.ndarray, Image.Image], hash_size: int = 32) -> np.ndarray:
20
  if isinstance(image, str):
21
  content = get_file_content(image)
 
7
  from fastapi import HTTPException
8
  from app.utils.file_utils import get_file_content
9
 
 
 
 
 
 
 
 
 
 
10
  def preprocess_image(image: Union[str, np.ndarray, Image.Image], hash_size: int = 32) -> np.ndarray:
11
  if isinstance(image, str):
12
  content = get_file_content(image)