import face_detection  # noqa: F401 -- not used in this file; kept in case other code relies on it being imported. TODO confirm
import torch

from models import Wav2Lip

# Prefix that ``torch.nn.DataParallel`` adds to every state_dict key (the
# wrapped network lives under the ``.module`` attribute).
_DATA_PARALLEL_PREFIX = 'module.'


def _get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path, device=None):
    """Deserialize a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Target device string. Only the exact value ``'cuda'`` keeps
            tensors on the devices they were saved on; any other value remaps
            every storage to CPU. Defaults to ``'cuda'`` if available, else
            ``'cpu'``.

    Returns:
        The object ``torch.load`` reconstructs from the file.
    """
    if device is None:
        device = _get_device()
    if device == 'cuda':
        return torch.load(checkpoint_path)
    # Remap storages to CPU so checkpoints saved on a GPU load on CPU-only hosts.
    return torch.load(checkpoint_path, map_location='cpu')


def _strip_prefix(state_dict, prefix=_DATA_PARALLEL_PREFIX):
    """Return a new dict whose keys have a single leading *prefix* removed.

    Only a leading occurrence is removed, so module names that merely contain
    the prefix text (e.g. ``'audio_submodule.weight'``) are left intact.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(path):
    """Build a ``Wav2Lip`` model, load weights from *path*, and return it in eval mode.

    The weights are read from the checkpoint's ``"state_dict"`` entry. Any
    leading ``'module.'`` (added when the model was trained wrapped in
    ``DataParallel``) is stripped from the keys before loading.

    Args:
        path: Path to the checkpoint file.

    Returns:
        The ``Wav2Lip`` model, moved to ``'cuda'`` when available (else
        ``'cpu'``) and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    device = _get_device()
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path, device)
    try:
        state_dict = checkpoint["state_dict"]
    except KeyError:
        raise KeyError(
            "Checkpoint {!r} has no 'state_dict' entry".format(path)
        ) from None
    model.load_state_dict(_strip_prefix(state_dict))
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(device).eval()