import os
import time
import shutil

import torch
import cv2
import numpy as np

from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from utils.common import load_checkpoint, RELEASED_WEIGHTS
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image, is_image_file, is_video_file
from tqdm import tqdm
from color_transfer import color_transfer_pytorch
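
# matplotlib and moviepy are optional dependencies: matplotlib is only needed
# by transform_and_show and moviepy only by transform_video, so a missing
# package is tolerated here and reported at call time instead.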
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
    from moviepy.video.io.VideoFileClip import VideoFileClip
except ImportError:
    ffmpeg_writer = None
    VideoFileClip = None


def profile(func):
    """Decorator that reports the wall-clock time of each call."""
    def wrap(*args, **kwargs):
        started_at = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - started_at
        print(f"Processed in {elapsed:.3f}s")
        return result
    return wrap


def auto_load_weight(weight, version=None, map_location=None):
    """Detect the Generator version from the weight name and load the model."""
    project_dir = os.path.dirname(os.path.abspath(__file__))
    cache_dir = os.path.join(project_dir, ".cache")
    weight_name = os.path.basename(weight)
    cached_weight = os.path.join(cache_dir, weight_name)

    # Prefer a previously downloaded copy in .cache if one exists.
    if os.path.exists(cached_weight):
        weight = cached_weight

    if version is not None:
        version = version.lower()
        assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
        cls = {
            "v1": GeneratorV1,
            "v2": GeneratorV2,
            "v3": GeneratorV3,
        }[version]
    else:
        # No explicit version given, infer it from the weight name.
        if weight_name in RELEASED_WEIGHTS:
            version = RELEASED_WEIGHTS[weight_name][0]
            return auto_load_weight(weight, version=version, map_location=map_location)
        elif weight_name.lower().startswith("generatorv2"):
            cls = GeneratorV2
        elif weight_name.lower().startswith("generatorv3"):
            cls = GeneratorV3
        elif weight_name.lower().startswith("generator"):
            cls = GeneratorV1
        else:
            raise ValueError(
                f"Cannot infer the model version from {weight_name}, "
                "you might need to specify it explicitly via `version`"
            )
    model = cls()
    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
    model.eval()
    return model
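
# A minimal sketch of both loading paths ("hayao:v2" and the checkpoint path
# below are illustrative values, not files this module ships with):
#
#   G = auto_load_weight("hayao:v2")  # released weight name, version inferred
#   G = auto_load_weight("ckpts/generator_hayao.pt", version="v1",
#                        map_location="cpu")  # local file, version given explicitly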


class Predictor:
    """
    Generic class for transferring an image into an anime-style image.
    """
    def __init__(
        self,
        weight='hayao',
        device='cuda',
        amp=True,
        retain_color=False,
        imgsz=None,
    ):
        if not torch.cuda.is_available():
            device = 'cpu'
            # Mixed precision is only useful on GPU.
            amp = False
            print("Use CPU device")
        else:
            print(f"Use GPU {torch.cuda.get_device_name()}")

        self.imgsz = imgsz
        self.retain_color = retain_color
        self.amp = amp
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)
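
    # A minimal usage sketch ("hayao:v2" and "photo.jpg" are placeholders):
    #
    #   predictor = Predictor(weight="hayao:v2", device="cuda")
    #   image = read_image("photo.jpg")
    #   anime = predictor.transform(image)[0]  # uint8 RGB array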

    def transform_and_show(
        self,
        image_path,
        figsize=(18, 10),
        save_path=None
    ):
        if plt is None:
            raise ImportError("matplotlib is not installed, please install it to plot results")

        image = resize_image(read_image(image_path))
        anime_img = self.transform(image)
        anime_img = anime_img.astype('uint8')

        fig = plt.figure(figsize=figsize)
        fig.add_subplot(1, 2, 1)
        plt.imshow(image)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        plt.imshow(anime_img[0])
        plt.axis('off')
        plt.tight_layout()
        # Save before plt.show(), which can leave an empty figure behind.
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

    def transform(self, image, denorm=True):
        '''
        Transform an image (or batch of images) into its anime version.

        @Arguments:
            - image: np.array, shape = (batch, height, width, channels) or (height, width, channels)

        @Returns:
            - anime version of the image: np.array
        '''
        with torch.no_grad():
            image = self.preprocess_images(image)
            fake = self.G(image)

            if self.retain_color:
                # Transfer the input colors onto the generated image; the
                # transfer works in [0, 1], so map back to [-1, 1] afterwards.
                fake = color_transfer_pytorch(fake, image)
                fake = (fake / 0.5) - 1.0
            fake = fake.detach().cpu().numpy()
            # NCHW -> NHWC
            fake = fake.transpose(0, 2, 3, 1)

            if denorm:
                fake = denormalize_input(fake, dtype=np.uint8)
        return fake
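
    # Note: self.amp is stored but not applied in this snippet. A sketch of
    # how it could gate the forward pass (torch.autocast is standard PyTorch):
    #
    #   with torch.autocast(device_type=self.device_type, enabled=self.amp):
    #       fake = self.G(image)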

    def read_and_resize(self, path, max_size=1536):
        image = read_image(path)
        _, ext = os.path.splitext(path)
        h, w = image.shape[:2]
        if self.imgsz is not None:
            image = resize_image(image, width=self.imgsz)
        elif max(h, w) > max_size:
            print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}")
            # Resize along the longer side (>= also covers square images).
            image = resize_image(
                image,
                width=max_size if w >= h else None,
                height=max_size if w < h else None,
            )
            # Cache the downscaled copy next to the source as .jpg (BGR for cv2).
            cv2.imwrite(path.replace(ext, ".jpg"), image[:, :, ::-1])
        else:
            image = resize_image(image)

        return image

    @profile
    def transform_file(self, file_path, save_path):
        if not is_image_file(save_path):
            raise ValueError(f"{save_path} is not a valid image file path")

        image = self.read_and_resize(file_path)
        anime_img = self.transform(image)[0]
        # Convert RGB back to BGR for cv2.imwrite.
        cv2.imwrite(save_path, anime_img[..., ::-1])
        print(f"Anime image saved to {save_path}")
        return anime_img

    @profile
    def transform_gif(self, file_path, save_path, batch_size=4):
        import imageio

        def _preprocess_gif(img):
            # GIF frames may carry an alpha channel; drop it before inference.
            if img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            return resize_image(img)

        images = imageio.mimread(file_path)
        images = np.stack([
            _preprocess_gif(img)
            for img in images
        ])

        anime_gif = np.zeros_like(images)

        # Slicing past the end of the array is safe, so the last (possibly
        # partial) batch is handled by the loop itself.
        for i in tqdm(range(0, len(images), batch_size)):
            end = i + batch_size
            anime_gif[i: end] = self.transform(images[i: end])

        imageio.mimsave(save_path, anime_gif)
        print(f"Anime gif saved to {save_path}")

    @profile
    def transform_in_dir(self, img_dir, dest_dir, max_images=0):
        '''
        Read all images from img_dir, transform them and write the results
        to dest_dir.
        '''
        os.makedirs(dest_dir, exist_ok=True)

        files = os.listdir(img_dir)
        files = [f for f in files if is_image_file(f)]
        print(f'Found {len(files)} images in {img_dir}')

        if max_images:
            files = files[:max_images]

        bar = tqdm(files)
        for fname in bar:
            path = os.path.join(img_dir, fname)
            image = self.read_and_resize(path)
            anime_img = self.transform(image)[0]

            fname = os.path.splitext(fname)[0]
            cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
            bar.set_description(f"{fname} {image.shape}")

    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
        '''
        Transform a video into its animation version.
        https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
        '''
        if VideoFileClip is None or ffmpeg_writer is None:
            raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`")

        end = end or None

        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'{input_path} does not exist')

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Writing directly to Google Drive is slow, so write to a local
        # temporary file first and move it into place at the end.
        is_gg_drive = '/drive/' in output_path
        temp_file = ''
        if is_gg_drive:
            temp_file = f'tmp_anime.{output_path.split(".")[-1]}'

        def transform_and_write(frames, count, writer):
            anime_images = self.transform(frames)
            for i in range(count):
                img = np.clip(anime_images[i], 0, 255)
                writer.write_frame(img)

        video_clip = VideoFileClip(input_path, audio=False)
        if start or end:
            video_clip = video_clip.subclip(start, end)

        video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
            temp_file or output_path,
            video_clip.size, video_clip.fps,
            codec="libx264",
            ffmpeg_params=None)

        total_frames = round(video_clip.fps * video_clip.duration)
        print(f'Transforming video {input_path}, {total_frames} frames, size: {video_clip.size}')

        # Collect frames into fixed-size batches before running the generator.
        batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
        frame_count = 0
        frames = np.zeros(batch_shape, dtype=np.float32)
        for frame in tqdm(video_clip.iter_frames(), total=total_frames):
            try:
                frames[frame_count] = frame
                frame_count += 1
                if frame_count == batch_size:
                    transform_and_write(frames, frame_count, video_writer)
                    frame_count = 0
            except Exception as e:
                print(e)
                break

        # Flush the remaining, partially filled batch.
        if frame_count != 0:
            transform_and_write(frames, frame_count, video_writer)

        # Close the writer before moving the temp file so it is fully flushed.
        video_writer.close()
        if temp_file:
            shutil.move(temp_file, output_path)

        print(f'Animation video saved to {output_path}')

    def preprocess_images(self, images):
        '''
        Preprocess images for inference.

        @Arguments:
            - images: np.ndarray

        @Returns:
            - images: torch.tensor
        '''
        images = images.astype(np.float32)
        # Scale pixel values into the range the generator was trained on.
        images = normalize_input(images)
        images = torch.from_numpy(images)
        images = images.to(self.device)

        # Add a batch dimension for a single image.
        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        # NHWC -> NCHW
        images = images.permute(0, 3, 1, 2)

        return images
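
    # Shape flow, assuming a single 256x256 RGB input:
    #   (256, 256, 3) uint8 -> (256, 256, 3) float32 in model range
    #   -> (1, 256, 256, 3) tensor on self.device -> (1, 3, 256, 256)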


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--weight',
        type=str,
        default="hayao:v2",
        help=f'Model weight, can be a path or a pretrained name {tuple(RELEASED_WEIGHTS.keys())}'
    )
    parser.add_argument('--src', type=str, help='Source, can be a directory containing images, an image file or a video file.')
    parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu')
    parser.add_argument('--imgsz', type=int, default=None, help='Resize images to the specified size if provided')
    parser.add_argument('--out', type=str, default='inference_images', help='Output, can be a directory or a file')
    parser.add_argument(
        '--retain-color',
        action='store_true',
        help='If provided, the generated image will retain the original colors of the input image')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size when inferring on video')
    parser.add_argument('--start', type=int, default=0, help='Start time of video (seconds)')
    parser.add_argument('--end', type=int, default=0, help='End time of video (seconds), 0 if not set')

    return parser.parse_args()
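
# Example invocations (script and file names are illustrative):
#   python inference.py --weight hayao:v2 --src photo.jpg --out anime.jpg
#   python inference.py --weight hayao:v2 --src clips/in.mp4 --out clips/anime.mp4 --batch-size 8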


if __name__ == '__main__':
    args = parse_args()

    predictor = Predictor(
        args.weight,
        args.device,
        retain_color=args.retain_color,
        imgsz=args.imgsz,
    )

    if not os.path.exists(args.src):
        raise FileNotFoundError(args.src)

    if is_video_file(args.src):
        predictor.transform_video(
            args.src,
            args.out,
            args.batch_size,
            start=args.start,
            end=args.end
        )
    elif os.path.isdir(args.src):
        predictor.transform_in_dir(args.src, args.out)
    elif os.path.isfile(args.src):
        save_path = args.out
        if not is_image_file(args.out):
            os.makedirs(args.out, exist_ok=True)
            save_path = os.path.join(args.out, os.path.basename(args.src))

        if args.src.lower().endswith('.gif'):
            # GIFs go through imageio rather than the single-image path.
            predictor.transform_gif(args.src, save_path, args.batch_size)
        else:
            predictor.transform_file(args.src, save_path)
    else:
        raise NotImplementedError(f"{args.src} is not supported")