|
from datetime import datetime |
|
import os |
|
import random |
|
import torch |
|
import torch.optim as optim |
|
import torch.nn.functional as F |
|
|
|
import Image |
|
import ModelFormat |
|
from StyleTransferLoss import StyleTransferLoss |
|
import onnxruntime as rt |
|
|
|
import cv2 |
|
from insightface.data import get_image as ins_get_image |
|
from insightface.app import FaceAnalysis |
|
import face_align |
|
|
|
from StyleTransferModel_128 import StyleTransferModel |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
# Path to the pretrained InsightFace "inswapper" ONNX model. It acts as the
# teacher network: train() optimizes the student model to match its outputs.
inswapper_128_path = 'inswapper_128.onnx'

# Working resolution of the teacher model (128x128 aligned face crops).
img_size = 128

# Prefer CUDA for ONNX Runtime, transparently falling back to CPU.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# NOTE: module import has side effects — the ONNX model is loaded here.
inswapperInferenceSession = rt.InferenceSession(inswapper_128_path, providers=providers)

# Face detector / landmark / identity-embedding pipeline used both for
# training samples and for the fixed validation fixtures below.
faceAnalysis = FaceAnalysis(name='buffalo_l')
faceAnalysis.prepare(ctx_id=0, det_size=(512, 512))
|
|
|
def get_device():
    """Return the preferred torch device: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
|
# Shared loss module (content term + optional identity term), created once on
# the training device and used by both train() and validate().
style_loss_fn = StyleTransferLoss().to(get_device())
|
|
|
def train(datasetDir, learning_rate=0.0001, model_path=None, outputModelFolder='', saveModelEachSteps=1, stopAtSteps=None, logDir=None, previewDir=None, saveAs_onnx=False, resolutions=[128], enableDataAugmentation=False):
    """Train the StyleTransferModel to mimic the inswapper teacher model.

    Every step samples a random (target, source) image pair from datasetDir.
    When the two indices differ, the teacher inswapper model produces the
    supervision target; when they coincide, the step degenerates to plain
    reconstruction of the aligned target face.

    Args:
        datasetDir: directory of face images (one detectable face expected).
        learning_rate: Adam learning rate.
        model_path: optional checkpoint to resume from; the starting step
            count is parsed from its filename ("...-<steps>.pth").
        outputModelFolder: directory for checkpoints ('' = current dir).
        saveModelEachSteps: checkpoint + validation interval (in new steps).
        stopAtSteps: stop after this many NEW steps (None = run forever).
        logDir: TensorBoard log root (None disables logging).
        previewDir: directory for validation preview JPEGs (None disables).
        saveAs_onnx: additionally export each checkpoint to ONNX.
        resolutions: crop resolutions cycled one per step; non-128 outputs
            are resized back to 128 before the loss. (Mutable default is
            safe: the list is only read, never modified.)
        enableDataAugmentation: grayscale the target image every 2nd step.
    """
    device = get_device()
    print(f"Using device: {device}")

    model = StyleTransferModel().to(device)

    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        print(f"Loaded model from {model_path}")

        # Resume the global step counter from the "reswapper-<N>.pth" name.
        lastSteps = int(model_path.split('-')[-1].split('.')[0])
        print(f"Resuming training from step {lastSteps}")
    else:
        lastSteps = 0

    model.train()

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if logDir is not None:
        train_writer = SummaryWriter(os.path.join(logDir, "training"))
        val_writer = SummaryWriter(os.path.join(logDir, "validation"))

    steps = 0

    # List of image filenames in the dataset directory.
    imageFiles = os.listdir(datasetDir)

    resolutionIndex = 0

    while True:
        start_time = datetime.now()

        resolution = resolutions[resolutionIndex % len(resolutions)]

        targetFaceIndex = random.randint(0, len(imageFiles) - 1)
        sourceFaceIndex = random.randint(0, len(imageFiles) - 1)

        target_img = cv2.imread(f"{datasetDir}/{imageFiles[targetFaceIndex]}")
        if enableDataAugmentation and steps % 2 == 0:
            # Grayscale round-trip keeps 3 channels but strips color cues.
            target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY)
            target_img = cv2.cvtColor(target_img, cv2.COLOR_GRAY2BGR)
        faces = faceAnalysis.get(target_img)

        if targetFaceIndex != sourceFaceIndex:
            source_img = cv2.imread(f"{datasetDir}/{imageFiles[sourceFaceIndex]}")
            faces2 = faceAnalysis.get(source_img)
        else:
            faces2 = faces

        if len(faces) > 0 and len(faces2) > 0:
            new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, img_size)
            blob = Image.getBlob(new_aligned_face)
            latent = Image.getLatent(faces2[0])
        else:
            # No detectable face in target or source image: skip the sample.
            continue

        if targetFaceIndex != sourceFaceIndex:
            # Teacher supervision: run the reference inswapper model.
            # (Renamed from `input`, which shadowed the builtin.)
            ort_inputs = {inswapperInferenceSession.get_inputs()[0].name: blob,
                          inswapperInferenceSession.get_inputs()[1].name: latent}

            expected_output = inswapperInferenceSession.run([inswapperInferenceSession.get_outputs()[0].name], ort_inputs)[0]
        else:
            # Identical indices: reconstruction target is the input itself.
            expected_output = blob

        expected_output_tensor = torch.from_numpy(expected_output).to(device)

        if resolution != 128:
            # Re-crop the input at the training resolution; the teacher
            # target above intentionally stays at 128.
            new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, resolution)
            blob = Image.getBlob(new_aligned_face, (resolution, resolution))

        latent_tensor = torch.from_numpy(latent).to(device)
        target_input_tensor = torch.from_numpy(blob).to(device)

        optimizer.zero_grad()
        output = model(target_input_tensor, latent_tensor)

        if resolution != 128:
            output = F.interpolate(output, size=(128, 128), mode='bilinear', align_corners=False)

        content_loss, identity_loss = style_loss_fn(output, expected_output_tensor)

        # BUGFIX: build the total out-of-place. The original
        # `loss = content_loss` followed by `loss += identity_loss` mutated
        # content_loss in place, so the "Loss/content_loss" scalar logged
        # below actually contained the TOTAL loss.
        if identity_loss is not None:
            loss = content_loss + identity_loss
        else:
            loss = content_loss

        loss.backward()

        optimizer.step()

        steps += 1
        totalSteps = steps + lastSteps

        if logDir is not None:
            train_writer.add_scalar("Loss/total", loss.item(), totalSteps)
            train_writer.add_scalar("Loss/content_loss", content_loss.item(), totalSteps)

            if identity_loss is not None:
                train_writer.add_scalar("Loss/identity_loss", identity_loss.item(), totalSteps)

        elapsed_time = datetime.now() - start_time

        print(f"Total Steps: {totalSteps}, Step: {steps}, Loss: {loss.item():.4f}, Elapsed time: {elapsed_time}")

        if steps % saveModelEachSteps == 0:
            outputModelPath = f"reswapper-{totalSteps}.pth"
            if outputModelFolder != '':
                outputModelPath = f"{outputModelFolder}/{outputModelPath}"
            saveModel(model, outputModelPath)

            validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256 = validate(outputModelPath)
            if previewDir is not None:
                cv2.imwrite(f"{previewDir}/{totalSteps}.jpg", swapped_face)
                cv2.imwrite(f"{previewDir}/{totalSteps}_256.jpg", swapped_face_256)

            if logDir is not None:
                val_writer.add_scalar("Loss/total", validation_total_loss.item(), totalSteps)
                val_writer.add_scalar("Loss/content_loss", validation_content_loss.item(), totalSteps)
                if validation_identity_loss is not None:
                    val_writer.add_scalar("Loss/identity_loss", validation_identity_loss.item(), totalSteps)

            if saveAs_onnx:
                ModelFormat.save_as_onnx_model(outputModelPath)

            if stopAtSteps is not None and steps == stopAtSteps:
                # `return` instead of the original exit(): lets callers
                # (tests, notebooks) continue after training finishes.
                return

        resolutionIndex += 1
|
|
|
def saveModel(model, outputModelPath):
    """Serialize the model's parameters (its state_dict) to outputModelPath."""
    state = model.state_dict()
    torch.save(state, outputModelPath)
|
|
|
def load_model(model_path):
    """Load a StyleTransferModel checkpoint and return it in eval mode.

    Uses strict=False so checkpoints with missing/extra keys still load.
    """
    device = get_device()
    model = StyleTransferModel().to(device)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state, strict=False)
    model.eval()
    return model
|
|
|
def swap_face(model, target_face, source_face_latent):
    """Swap one face with the trained model.

    Args:
        model: a StyleTransferModel in eval mode.
        target_face: numpy blob of the aligned target face.
        source_face_latent: numpy identity latent of the source face.

    Returns:
        (swapped_face, swapped_tensor): the post-processed image and the
        raw model output tensor.
    """
    device = get_device()

    target_tensor = torch.from_numpy(target_face).to(device)
    latent_tensor = torch.from_numpy(source_face_latent).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        swapped_tensor = model(target_tensor, latent_tensor)

    return Image.postprocess_face(swapped_tensor), swapped_tensor
|
|
|
|
|
# ---------------------------------------------------------------------------
# Fixed validation fixtures, built once at import time from the insightface
# bundled sample image 't1' (a photo containing several faces).
# ---------------------------------------------------------------------------
test_img = ins_get_image('t1')

test_faces = faceAnalysis.get(test_img)
# Sort detections left-to-right so the face indices below are deterministic.
test_faces = sorted(test_faces, key = lambda x : x.bbox[0])
# Leftmost face is the validation target, aligned at the 128px working size.
test_target_face, _ = face_align.norm_crop2(test_img, test_faces[0].kps, img_size)
test_target_face = Image.getBlob(test_target_face)
# Third face from the left supplies the source identity latent.
test_l = Image.getLatent(test_faces[2])

# Same target face aligned at 256px for the high-resolution preview output.
test_target_face_256, _ = face_align.norm_crop2(test_img, test_faces[0].kps, 256)
test_target_face_256 = Image.getBlob(test_target_face_256, (256, 256))

test_input = {inswapperInferenceSession.get_inputs()[0].name: test_target_face,
              inswapperInferenceSession.get_inputs()[1].name: test_l}

# Reference output from the teacher inswapper model; validate() scores the
# student against this fixed tensor.
test_inswapperOutput = inswapperInferenceSession.run([inswapperInferenceSession.get_outputs()[0].name], test_input)[0]
|
|
|
def validate(modelPath):
    """Evaluate a saved checkpoint against the fixed validation fixtures.

    Args:
        modelPath: path to a .pth checkpoint produced by saveModel().

    Returns:
        (total_loss, content_loss, identity_loss_or_None,
         swapped_face, swapped_face_256) — losses are torch scalars, the
        two faces are the 128px and 256px post-processed preview images.
    """
    model = load_model(modelPath)
    swapped_face, swapped_tensor = swap_face(model, test_target_face, test_l)
    swapped_face_256, _ = swap_face(model, test_target_face_256, test_l)

    validation_content_loss, validation_identity_loss = style_loss_fn(swapped_tensor, torch.from_numpy(test_inswapperOutput).to(get_device()))

    # BUGFIX: out-of-place addition. The original `+=` mutated
    # validation_content_loss in place, so the returned (and logged)
    # content loss silently included the identity term.
    if validation_identity_loss is not None:
        validation_total_loss = validation_content_loss + validation_identity_loss
    else:
        validation_total_loss = validation_content_loss

    return validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256
|
|
|
def main():
    """Entry point: set up output folders and launch training on FFHQ."""
    outputModelFolder = "model"
    modelPath = None

    logDir = "training/log"
    previewDir = "training/preview"
    datasetDir = "FFHQ"

    # Make sure checkpoint and preview destinations exist up front.
    for folder in (outputModelFolder, previewDir):
        os.makedirs(folder, exist_ok=True)

    # Each run logs into its own timestamped subdirectory.
    runLogDir = f"{logDir}/{datetime.now().strftime('%Y%m%d %H%M%S')}"

    train(
        datasetDir=datasetDir,
        model_path=modelPath,
        learning_rate=0.0001,
        outputModelFolder=outputModelFolder,
        saveModelEachSteps = 1000,
        stopAtSteps = 70000,
        logDir=runLogDir,
        previewDir=previewDir)
|
|
|
# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()