|
from collections import namedtuple |
|
import torch |
|
from torch.utils import model_zoo |
|
import requests |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
|
|
from src.FaceDetector.face_detector import FaceDetector |
|
from src.FaceId.faceid import FaceId |
|
from src.Generator.fs_networks_fix import Generator_Adain_Upsample |
|
from src.PostProcess.ParsingModel.model import BiSeNet |
|
from src.PostProcess.GFPGAN.gfpgan import GFPGANer |
|
from src.Blend.blend import BlendModule |
|
|
|
|
|
model = namedtuple("model", ["url", "model"]) |
|
|
|
models = { |
|
"face_detector": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx", |
|
model=FaceDetector, |
|
), |
|
"arcface": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit", |
|
model=FaceId, |
|
), |
|
"generator_224": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth", |
|
model=Generator_Adain_Upsample, |
|
), |
|
"generator_512": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth", |
|
model=Generator_Adain_Upsample, |
|
), |
|
"parsing_model": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth", |
|
model=BiSeNet, |
|
), |
|
"gfpgan": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth", |
|
model=GFPGANer, |
|
), |
|
"blend_module": model( |
|
url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit", |
|
model=BlendModule |
|
) |
|
} |
|
|
|
|
|
def get_model( |
|
model_name: str, |
|
device: torch.device, |
|
load_state_dice: bool, |
|
model_path: Path, |
|
**kwargs, |
|
): |
|
dst_dir = Path.cwd() / "weights" |
|
dst_dir.mkdir(exist_ok=True) |
|
|
|
url = models[model_name].url if not model_path.is_file() else str(model_path) |
|
|
|
if load_state_dice: |
|
model = models[model_name].model(**kwargs) |
|
|
|
if Path(url).is_file(): |
|
state_dict = torch.load(url) |
|
else: |
|
state_dict = model_zoo.load_url( |
|
url, |
|
model_dir=str(dst_dir), |
|
progress=True, |
|
map_location="cpu", |
|
) |
|
|
|
model.load_state_dict(state_dict) |
|
|
|
model.to(device) |
|
model.eval() |
|
else: |
|
dst_path = Path(url) |
|
|
|
if not dst_path.is_file(): |
|
dst_path = dst_dir / Path(url).name |
|
|
|
if not dst_path.is_file(): |
|
print(f"Downloading: '{url}' to {dst_path}") |
|
response = requests.get(url, stream=True) |
|
if int(response.status_code) == 200: |
|
file_size = int(response.headers["Content-Length"]) / (2 ** 20) |
|
chunk_size = 1024 |
|
bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]" |
|
with open(dst_path, "wb") as handle: |
|
with tqdm(total=file_size, bar_format=bar_format) as pbar: |
|
for data in response.iter_content(chunk_size=chunk_size): |
|
handle.write(data) |
|
pbar.update(len(data) / (2 ** 20)) |
|
else: |
|
raise ValueError( |
|
f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually." |
|
) |
|
|
|
kwargs.update({"model_path": str(dst_path), "device": device}) |
|
|
|
model = models[model_name].model(**kwargs) |
|
|
|
return model |
|
|