import spaces
from flask import Flask, request, jsonify
import os
from werkzeug.utils import secure_filename
import cv2
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os

app = Flask(__name__)

# Upload handling: where incoming videos are staged and which containers are accepted.
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'webm'}
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    # Flask refuses request bodies larger than this many bytes (16 MiB).
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Use the first CUDA GPU when one is visible, otherwise run on CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector. post_process=False presumably leaves crops un-normalized
# (process_frame rescales them by 1/255) — confirm against facenet_pytorch docs.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()

# Single-logit classifier whose sigmoid output process_frame treats as P(fake).
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
# Model Credits: https://huggingface.co/spaces/dhairyashah/deepfake-alpha-version/blob/main/CREDITS.md
# NOTE: torch.load unpickles the file; only ever point this at a trusted, local checkpoint.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE).eval()

def allowed_file(filename):
    """Report whether *filename* carries an accepted video extension.

    The extension is whatever follows the last '.' and is compared
    case-insensitively against ALLOWED_EXTENSIONS. Names that contain no
    '.' at all are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in ALLOWED_EXTENSIONS

@spaces.GPU
def process_frame(frame):
    """Classify the face found in one RGB frame as real or fake.

    Args:
        frame: RGB image handed directly to the MTCNN detector
            (analyze_video passes an OpenCV frame converted to RGB).

    Returns:
        A ``(label, score)`` pair: ``score`` is the sigmoid of the model's
        single output and ``label`` is ``"fake"`` when ``score >= 0.5``,
        otherwise ``"real"``. Returns ``(None, None)`` when MTCNN finds no face.
    """
    detected = mtcnn(frame)
    if detected is None:
        return None, None

    # Add a batch axis and resize the crop to the 256x256 input used here.
    batch = F.interpolate(
        detected.unsqueeze(0),
        size=(256, 256),
        mode='bilinear',
        align_corners=False,
    )
    # Move to the model's device as float32 and rescale pixels by 1/255.
    batch = batch.to(DEVICE).to(torch.float32) / 255.0

    with torch.no_grad():
        score = torch.sigmoid(model(batch).squeeze(0)).item()

    label = "fake" if score >= 0.5 else "real"
    return label, score

@spaces.GPU
def analyze_video(video_path, sample_rate=30):
    """Return the percentage of sampled, face-bearing frames classified as fake.

    Every ``sample_rate``-th frame (starting at frame 0) is converted from
    OpenCV's BGR channel order to RGB and passed to ``process_frame``.
    Sampled frames in which no face is detected are not counted.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        sample_rate: Analyze one frame out of every ``sample_rate`` frames;
            must be at least 1.

    Returns:
        ``fake_count / faces_found * 100``, or ``0`` when no sampled frame
        contained a face.

    Raises:
        ValueError: If ``sample_rate`` is less than 1, or the video cannot
            be opened. Previously an unreadable file silently scored 0%, which
            callers reported as "not a deepfake".
    """
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        frame_index = 0
        fake_count = 0
        total_processed = 0
        while True:
            if frame_index % sample_rate:
                # Frame not sampled: grab() advances the stream without the
                # retrieve()/BGR conversion work that read() would also do.
                if not cap.grab():
                    break
            else:
                ret, frame = cap.read()
                if not ret:
                    break
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                prediction, _confidence = process_frame(rgb_frame)
                if prediction is not None:
                    total_processed += 1
                    if prediction == "fake":
                        fake_count += 1
            frame_index += 1
    finally:
        # Always free the decoder, even when decoding or inference raises.
        cap.release()

    if total_processed == 0:
        return 0
    return (fake_count / total_processed) * 100

@app.route('/analyze', methods=['POST'])
def analyze_video_api():
    """Analyze an uploaded video for deepfake content.

    Expects a multipart upload in the ``video`` form field. Responds with JSON:

    * 200 ``{'fake_percentage': number, 'is_likely_deepfake': bool}`` on
      success; a video is flagged when ``fake_percentage`` is at least 60.
    * 400 ``{'error': ...}`` when the field is missing, no file was selected,
      or the extension is not in ALLOWED_EXTENSIONS.
    * 500 ``{'error': ...}`` when saving or analyzing the upload fails.

    The upload is stored under a random, request-unique name and is always
    deleted afterwards.
    """
    import uuid  # function-scope import keeps the file's import header untouched

    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400

    file = request.files['video']

    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': f'Invalid file type: {file.filename}'}), 400

    # A unique name per request prevents concurrent uploads that share a client
    # filename from overwriting or deleting each other's file. The extension
    # was validated above; secure_filename is kept as defense in depth.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filename = secure_filename(f"{uuid.uuid4().hex}.{extension}")
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    try:
        file.save(filepath)
        fake_percentage = analyze_video(filepath)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # Single cleanup point; the file may not exist if save() failed early.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass

    result = {
        'fake_percentage': round(fake_percentage, 2),
        'is_likely_deepfake': fake_percentage >= 60
    }
    return jsonify(result), 200

if __name__ == '__main__':
    # Start Flask's built-in server on every interface at port 7860
    # (presumably the port the hosting environment expects — confirm before changing).
    app.run(port=7860, host='0.0.0.0')