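"""Training script for reswapper.

Trains StyleTransferModel to reproduce the output of the pretrained
inswapper_128 ONNX model on randomly paired faces; the pretrained swapper
effectively acts as a teacher whose outputs are the training targets.
"""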
from datetime import datetime
import os
import random

import cv2
import onnxruntime as rt
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image

# Local modules from this repository
import face_align
import Image
import ModelFormat
from StyleTransferLoss import StyleTransferLoss
from StyleTransferModel_128 import StyleTransferModel
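
# Teacher model: the pretrained inswapper_128 ONNX model. Its swapped faces
# are the targets that StyleTransferModel is trained to reproduce.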
inswapper_128_path = 'inswapper_128.onnx'
img_size = 128
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
inswapperInferenceSession = rt.InferenceSession(inswapper_128_path, providers=providers)
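# insightface analysis model used for face detection, landmarks, and identity latents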
faceAnalysis = FaceAnalysis(name='buffalo_l')
faceAnalysis.prepare(ctx_id=0, det_size=(512, 512))

def get_device():
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

style_loss_fn = StyleTransferLoss().to(get_device())
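
# Each training step samples a random (target, source) image pair from
# datasetDir, computes the inswapper teacher output for that pair, and
# optimizes StyleTransferModel to match it. The loop runs indefinitely
# unless stopAtSteps is set.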
def train(datasetDir, learning_rate=0.0001, model_path=None, outputModelFolder='',
          saveModelEachSteps=1, stopAtSteps=None, logDir=None, previewDir=None,
          saveAs_onnx=False, resolutions=[128], enableDataAugmentation=False):
    device = get_device()
    print(f"Using device: {device}")

    model = StyleTransferModel().to(device)
    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        print(f"Loaded model from {model_path}")
        # The step count is encoded in the checkpoint filename: reswapper-<steps>.pth
        lastSteps = int(model_path.split('-')[-1].split('.')[0])
        print(f"Resuming training from step {lastSteps}")
    else:
        lastSteps = 0

    model.train()

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Initialize TensorBoard writers for training and validation curves
    if logDir is not None:
        train_writer = SummaryWriter(os.path.join(logDir, "training"))
        val_writer = SummaryWriter(os.path.join(logDir, "validation"))

    steps = 0
    images = os.listdir(datasetDir)
    resolutionIndex = 0
    # Training loop
    while True:
        start_time = datetime.now()
        # Cycle through the requested training resolutions
        resolution = resolutions[resolutionIndex % len(resolutions)]

        # Sample a random target/source pair; the indices may coincide
        targetFaceIndex = random.randint(0, len(images) - 1)
        sourceFaceIndex = random.randint(0, len(images) - 1)

        target_img = cv2.imread(f"{datasetDir}/{images[targetFaceIndex]}")
        if enableDataAugmentation and steps % 2 == 0:
            # Grayscale augmentation: drop color information on every other step
            target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY)
            target_img = cv2.cvtColor(target_img, cv2.COLOR_GRAY2BGR)

        faces = faceAnalysis.get(target_img)
        if targetFaceIndex != sourceFaceIndex:
            source_img = cv2.imread(f"{datasetDir}/{images[sourceFaceIndex]}")
            faces2 = faceAnalysis.get(source_img)
        else:
            faces2 = faces

        if len(faces) > 0 and len(faces2) > 0:
            new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, img_size)
            blob = Image.getBlob(new_aligned_face)
            latent = Image.getLatent(faces2[0])
        else:
            # Skip samples in which no face was detected
            continue
        if targetFaceIndex != sourceFaceIndex:
            # Different images: the inswapper teacher output is the training target
            inswapper_input = {inswapperInferenceSession.get_inputs()[0].name: blob,
                               inswapperInferenceSession.get_inputs()[1].name: latent}
            expected_output = inswapperInferenceSession.run(
                [inswapperInferenceSession.get_outputs()[0].name], inswapper_input)[0]
        else:
            # Same image: the model should reconstruct its own input
            expected_output = blob

        expected_output_tensor = torch.from_numpy(expected_output).to(device)

        if resolution != 128:
            # Re-crop the target face at the current training resolution
            new_aligned_face, _ = face_align.norm_crop2(target_img, faces[0].kps, resolution)
            blob = Image.getBlob(new_aligned_face, (resolution, resolution))

        latent_tensor = torch.from_numpy(latent).to(device)
        target_input_tensor = torch.from_numpy(blob).to(device)

        optimizer.zero_grad()
        output = model(target_input_tensor, latent_tensor)
        if resolution != 128:
            # The teacher output is always 128x128, so downscale before the loss
            output = F.interpolate(output, size=(128, 128), mode='bilinear', align_corners=False)

        content_loss, identity_loss = style_loss_fn(output, expected_output_tensor)
        loss = content_loss
        if identity_loss is not None:
            loss += identity_loss

        loss.backward()
        optimizer.step()
        steps += 1
        totalSteps = steps + lastSteps

        if logDir is not None:
            train_writer.add_scalar("Loss/total", loss.item(), totalSteps)
            train_writer.add_scalar("Loss/content_loss", content_loss.item(), totalSteps)
            if identity_loss is not None:
                train_writer.add_scalar("Loss/identity_loss", identity_loss.item(), totalSteps)

        elapsed_time = datetime.now() - start_time
        print(f"Total Steps: {totalSteps}, Step: {steps}, Loss: {loss.item():.4f}, Elapsed time: {elapsed_time}")

        if steps % saveModelEachSteps == 0:
            outputModelPath = f"reswapper-{totalSteps}.pth"
            if outputModelFolder != '':
                outputModelPath = f"{outputModelFolder}/{outputModelPath}"
            saveModel(model, outputModelPath)

            validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256 = validate(outputModelPath)
            if previewDir is not None:
                cv2.imwrite(f"{previewDir}/{totalSteps}.jpg", swapped_face)
                cv2.imwrite(f"{previewDir}/{totalSteps}_256.jpg", swapped_face_256)

            if logDir is not None:
                val_writer.add_scalar("Loss/total", validation_total_loss.item(), totalSteps)
                val_writer.add_scalar("Loss/content_loss", validation_content_loss.item(), totalSteps)
                if validation_identity_loss is not None:
                    val_writer.add_scalar("Loss/identity_loss", validation_identity_loss.item(), totalSteps)

            if saveAs_onnx:
                ModelFormat.save_as_onnx_model(outputModelPath)

            if stopAtSteps is not None and steps == stopAtSteps:
                exit()

        resolutionIndex += 1
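
# --- Checkpoint and inference helpers ---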
def saveModel(model, outputModelPath):
    torch.save(model.state_dict(), outputModelPath)

def load_model(model_path):
    device = get_device()
    model = StyleTransferModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    model.eval()
    return model

def swap_face(model, target_face, source_face_latent):
    device = get_device()
    target_tensor = torch.from_numpy(target_face).to(device)
    source_tensor = torch.from_numpy(source_face_latent).to(device)

    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

    swapped_face = Image.postprocess_face(swapped_tensor)
    return swapped_face, swapped_tensor
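
# Minimal usage sketch for inference (the checkpoint path is illustrative, and
# target_blob / source_latent stand for inputs prepared with Image.getBlob and
# Image.getLatent as in the validation setup below):
#
#   model = load_model("model/reswapper-70000.pth")
#   swapped_face, _ = swap_face(model, target_blob, source_latent)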

# Fixed validation sample: the insightface example image 't1' contains several
# faces; the identity of the third face (left to right) is swapped onto the first.
test_img = ins_get_image('t1')
test_faces = faceAnalysis.get(test_img)
test_faces = sorted(test_faces, key=lambda x: x.bbox[0])

test_target_face, _ = face_align.norm_crop2(test_img, test_faces[0].kps, img_size)
test_target_face = Image.getBlob(test_target_face)
test_l = Image.getLatent(test_faces[2])

test_target_face_256, _ = face_align.norm_crop2(test_img, test_faces[0].kps, 256)
test_target_face_256 = Image.getBlob(test_target_face_256, (256, 256))

# Reference output of the inswapper teacher for the validation pair
test_input = {inswapperInferenceSession.get_inputs()[0].name: test_target_face,
              inswapperInferenceSession.get_inputs()[1].name: test_l}
test_inswapperOutput = inswapperInferenceSession.run(
    [inswapperInferenceSession.get_outputs()[0].name], test_input)[0]

def validate(modelPath):
    model = load_model(modelPath)
    swapped_face, swapped_tensor = swap_face(model, test_target_face, test_l)
    swapped_face_256, _ = swap_face(model, test_target_face_256, test_l)

    validation_content_loss, validation_identity_loss = style_loss_fn(
        swapped_tensor, torch.from_numpy(test_inswapperOutput).to(get_device()))
    validation_total_loss = validation_content_loss
    if validation_identity_loss is not None:
        validation_total_loss += validation_identity_loss

    return validation_total_loss, validation_content_loss, validation_identity_loss, swapped_face, swapped_face_256

def main():
    outputModelFolder = "model"
    modelPath = None
    # modelPath = f"{outputModelFolder}/reswapper-<step>.pth"
    logDir = "training/log"
    previewDir = "training/preview"
    datasetDir = "FFHQ"

    os.makedirs(outputModelFolder, exist_ok=True)
    os.makedirs(previewDir, exist_ok=True)

    train(
        datasetDir=datasetDir,
        model_path=modelPath,
        learning_rate=0.0001,
        # resolutions=[128, 256],
        # enableDataAugmentation=True,
        outputModelFolder=outputModelFolder,
        saveModelEachSteps=1000,
        stopAtSteps=70000,
        logDir=f"{logDir}/{datetime.now().strftime('%Y%m%d %H%M%S')}",
        previewDir=previewDir)

if __name__ == "__main__":
    main()