|
import os.path |
|
import re |
|
import torch |
|
import time |
|
import tempfile |
|
|
|
import streamlit as st |
|
from training.zoo.classifiers import DeepFakeClassifier |
|
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set |
|
|
|
|
|
def load_model(path='weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23'):
    """Build the EfficientNet-B7 DeepFakeClassifier and load trained weights into it.

    Args:
        path: Checkpoint file to load. Defaults to the demo's bundled
            ``weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23``.

    Returns:
        The classifier with the checkpoint weights loaded, switched to eval mode.
    """
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    # map_location="cpu" maps every stored tensor to the CPU, so the checkpoint
    # loads even on a machine without the device it was saved from.
    checkpoint = torch.load(path, map_location="cpu")
    # Accept both a wrapper dict with a "state_dict" entry and a bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Drop a leading literal "module." from parameter names (typically added when
    # a model is saved while wrapped in DataParallel). The dot is escaped: the
    # old pattern "^module." matched any character there and would mangle keys
    # such as "modules_x.weight".
    model.load_state_dict(
        {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
        strict=True)
    model.eval()
    # Release the reference to the raw checkpoint object.
    del checkpoint
    return model
|
|
|
|
|
def write_bytesio_to_file(filename, bytesio):
    """Dump the full contents of an in-memory ``BytesIO`` buffer to ``filename``.

    The file is opened in binary write mode, so any existing file is truncated.
    """
    with open(filename, "wb") as sink, bytesio.getbuffer() as view:
        # The memoryview exposes the buffer's bytes without copying them and is
        # released on exit, so the BytesIO can be resized again afterwards.
        sink.write(view)
|
|
|
|
|
def load_video():
    """Show a file uploader and persist the uploaded video to a temporary file.

    Returns:
        The path of a temporary file containing the uploaded bytes, or ``None``
        if no file has been uploaded yet.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None

    video_data = uploaded_file.getvalue()
    # delete=False keeps the file on disk after it is closed, so st.video and the
    # video reader can open it by name. Closing it via ``with`` flushes the
    # buffered bytes and releases the handle deterministically. Before, this
    # relied on garbage collection, and on Windows an open NamedTemporaryFile
    # cannot be reopened by name.
    # NOTE(review): these files are never removed, and presumably a new one is
    # written on every Streamlit rerun; consider cleanup.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(video_data)
    return tfile.name
|
|
|
|
|
def inference(model, test_video):
    """Score a single video with ``model`` and write the prediction to the page.

    Args:
        model: The classifier to run, as produced by ``load_model`` in this file.
        test_video: Path of the video file to score.

    The first element of the list returned by ``predict_on_video_set`` is
    displayed with ``st.write``.
    """
    frames_per_video = 32
    input_size = 380

    reader = VideoReader()

    def read_frames(video_path):
        # Ask the reader for ``frames_per_video`` frames of ``video_path``.
        return reader.read_frames(video_path, num_frames=frames_per_video)

    extractor = FaceExtractor(read_frames)

    videos = [test_video]
    print("Predicting {} videos".format(len(videos)))

    # NOTE(review): test_dir is presumably joined with each entry of ``videos``;
    # with an absolute temp-file path os.path.join would ignore it. Confirm in
    # kernel_utils.predict_on_video_set.
    predictions = predict_on_video_set(
        face_extractor=extractor,
        input_size=input_size,
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=videos,
        num_workers=6,
        test_dir="test_video",
    )
    st.write("Prediction: ", predictions[0])
|
|
|
|
|
def main():
    """Streamlit entry point: upload a video, preview it, and run inference on demand."""
    st.title('Deepfake video inference demo')

    video_data_path = load_video()
    # Guard clause: without an uploaded file there is nothing to preview or
    # score. Previously the inference path could be reached with a None path.
    if video_data_path is None or not os.path.exists(video_data_path):
        return

    st.video(video_data_path)

    if st.button('Run on video'):
        # Load the (large) model only when inference is requested, rather than
        # on every script run. Uploads and other interactions also re-execute
        # main() and previously reloaded the checkpoint each time.
        model = load_model()
        st.write("Inference on video...")
        # perf_counter is monotonic and intended for measuring elapsed time.
        stime = time.perf_counter()
        inference(model, video_data_path)
        st.write("Elapsed time: ", time.perf_counter() - stime, " seconds")
|
|
|
|
|
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|