# cartoonize/inference.py
import torch
import cv2
import os
import numpy as np
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from utils.common import load_checkpoint, RELEASED_WEIGHTS
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image
from tqdm import tqdm
from torch import autocast
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
try:
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
from moviepy.video.io.VideoFileClip import VideoFileClip
except ImportError:
ffmpeg_writer = None
VideoFileClip = None
VALID_FORMATS = {
'jpeg', 'jpg', 'jpe',
'png', 'bmp',
}
def auto_load_weight(weight, version=None, map_location=None):
"""Auto load Generator version from weight."""
weight_name = os.path.basename(weight).lower()
if version is not None:
version = version.lower()
assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
# If version is provided, use it.
cls = {
"v1": GeneratorV1,
"v2": GeneratorV2,
"v3": GeneratorV3
}[version]
else:
        # Try to infer the class from the weight file name.
        # For convenience, the file name should start with the class name,
        # e.g. GeneratorV2_{anything}.pt
if weight_name in RELEASED_WEIGHTS:
version = RELEASED_WEIGHTS[weight_name][0]
return auto_load_weight(weight, version=version, map_location=map_location)
elif weight_name.startswith("generatorv2"):
cls = GeneratorV2
elif weight_name.startswith("generatorv3"):
cls = GeneratorV3
elif weight_name.startswith("generator"):
cls = GeneratorV1
else:
            raise ValueError((f"Cannot infer the model version from {weight_name}, "
                              "you may need to specify it explicitly"))
model = cls()
load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
model.eval()
return model
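
# Version-resolution sketch (the file names below are hypothetical): the version
# is inferred from the weight file name when not passed explicitly.
#   G = auto_load_weight('weights/GeneratorV2_Hayao.pt')        # -> GeneratorV2
#   G = auto_load_weight('weights/mystery.pt', version='v3')    # -> GeneratorV3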
class Predictor:
    def __init__(self, weight='hayao', device='cpu', amp=True):
        if not torch.cuda.is_available():
            device = 'cpu'
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        # Automatic Mixed Precision is only used on CUDA devices
        self.amp = amp and self.device_type == 'cuda'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)
    def transform_and_show(
        self,
        image_path,
        figsize=(18, 10),
        save_path=None
    ):
        if plt is None:
            raise ImportError("matplotlib is required for transform_and_show")
        image = resize_image(read_image(image_path))
        anime_img = self.transform(image).astype('uint8')
        fig = plt.figure(figsize=figsize)
        fig.add_subplot(1, 2, 1)
        plt.imshow(image)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        plt.imshow(anime_img[0])
        plt.axis('off')
        plt.tight_layout()
        # Save before show(): some backends clear the figure when it is shown
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()
    def transform(self, image, denorm=True):
        '''
        Transform an image (or batch of images) into animation style.
        @Arguments:
            - image: np.ndarray, RGB, shape = (height, width, channels)
              or (batch, height, width, channels)
        @Returns:
            - anime version of the image: np.ndarray, channel-last
        '''
        with torch.no_grad():
            image = self.preprocess_images(image)
            with autocast(self.device_type, enabled=self.amp):
                fake = self.G(image)
            fake = fake.detach().cpu().numpy()
            # Channel last
            fake = fake.transpose(0, 2, 3, 1)
            if denorm:
                fake = denormalize_input(fake, dtype=np.uint8)
        return fake
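
    # Usage sketch for transform() (hypothetical array, shapes illustrative):
    #   frame = np.zeros((256, 256, 3), dtype=np.uint8)   # RGB input
    #   anime = predictor.transform(frame)                # -> (1, 256, 256, 3), uint8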
    def transform_image(self, image):
        '''Transform a single RGB image and return the anime-style result.'''
        anime_img = self.transform(resize_image(image))[0]
        return anime_img
    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
        '''
        Read all images from img_dir, transform them and write the results
        to dest_dir.
        '''
        os.makedirs(dest_dir, exist_ok=True)
        files = os.listdir(img_dir)
        files = [f for f in files if self.is_valid_file(f)]
        print(f'Found {len(files)} images in {img_dir}')
        if max_images:
            files = files[:max_images]
        for fname in tqdm(files):
            image = cv2.imread(os.path.join(img_dir, fname))
            if image is None:
                print(f'Could not read {fname}, skipping')
                continue
            image = resize_image(image[:, :, ::-1])  # BGR -> RGB
            anime_img = self.transform(image)[0]
            # os.path.splitext is safer than str.replace for stripping extensions
            stem = os.path.splitext(fname)[0]
            cv2.imwrite(os.path.join(dest_dir, f'{stem}.jpg'), anime_img[..., ::-1])  # RGB -> BGR
    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
        '''
        Transform a video frame by frame (in batches of batch_size) and write
        the result to output_path. start/end are offsets in seconds.
        '''
        end = end or None

        def transform_and_write(frames, count, writer):
            anime_images = self.transform(frames)
            for i in range(count):
                img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
                writer.write(img[..., ::-1])  # RGB -> BGR for OpenCV

        video_capture = cv2.VideoCapture(input_path)
        frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(video_capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        if start or end:
            start_frame = int(start * fps)
            end_frame = int(end * fps) if end else frame_count
            video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            frame_count = end_frame - start_frame

        video_writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
        print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')

        batch_shape = (batch_size, frame_height, frame_width, 3)
        frames = np.zeros(batch_shape, dtype=np.uint8)
        frame_idx = 0
        try:
            for _ in tqdm(range(frame_count)):
                ret, frame = video_capture.read()
                if not ret:
                    break
                frames[frame_idx] = frame[:, :, ::-1]  # BGR -> RGB
                frame_idx += 1
                if frame_idx == batch_size:
                    transform_and_write(frames, frame_idx, video_writer)
                    frame_idx = 0
            # Flush the last, possibly incomplete batch
            if frame_idx:
                transform_and_write(frames[:frame_idx], frame_idx, video_writer)
        except Exception as e:
            print(e)
        finally:
            video_capture.release()
            video_writer.release()
        print(f'Animation video saved to {output_path}')
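
    # Usage sketch for transform_video() (paths hypothetical):
    #   predictor.transform_video('input.mp4', 'anime.mp4', batch_size=4,
    #                             start=0, end=10)  # transform the first 10 seconds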
    def transform_video1(self, video, batch_size, start, end):
        '''
        Transform a video and return the frames as a list of JPEG-encoded
        bytes instead of writing a video file. start/end are offsets in seconds.
        '''
        def transform_and_write(frames, count, video_buffer):
            anime_images = self.transform(frames)
            for i in range(count):
                img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
                success, encoded_image = cv2.imencode('.jpg', img[..., ::-1])  # RGB -> BGR
                if success:
                    video_buffer.append(encoded_image.tobytes())

        video_capture = cv2.VideoCapture(video)
        frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(video_capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f'Transforming video {frame_count} frames, size: ({frame_width}, {frame_height})')

        if start or end:
            start_frame = int(start * fps)
            end_frame = int(end * fps) if end else frame_count
            video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            frame_count = end_frame - start_frame

        video_buffer = []
        batch_shape = (batch_size, frame_height, frame_width, 3)
        frames = np.zeros(batch_shape, dtype=np.uint8)
        frame_idx = 0
        try:
            for _ in range(frame_count):
                ret, frame = video_capture.read()
                if not ret:
                    break
                frames[frame_idx] = frame[:, :, ::-1]  # BGR -> RGB
                frame_idx += 1
                if frame_idx == batch_size:
                    transform_and_write(frames, frame_idx, video_buffer)
                    frame_idx = 0
            # Flush the last, possibly incomplete batch
            if frame_idx:
                transform_and_write(frames[:frame_idx], frame_idx, video_buffer)
        except Exception as e:
            print(e)
        finally:
            video_capture.release()
        return video_buffer
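
    # Usage sketch for transform_video1() (path hypothetical): each entry in
    # the returned buffer is one JPEG-encoded frame.
    #   buffer = predictor.transform_video1('input.mp4', batch_size=4, start=0, end=0)
    #   with open('frame_0.jpg', 'wb') as f:
    #       f.write(buffer[0])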
def preprocess_images(self, images):
        '''
        Preprocess images for inference.
        @Arguments:
            - images: np.ndarray, channel-last RGB, values in [0, 255]
        @Returns:
            - images: torch.Tensor, channel-first, normalized to [-1, 1]
        '''
images = images.astype(np.float32)
# Normalize to [-1, 1]
images = normalize_input(images)
images = torch.from_numpy(images)
images = images.to(self.device)
# Add batch dim
if len(images.shape) == 3:
images = images.unsqueeze(0)
# channel first
images = images.permute(0, 3, 1, 2)
return images
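
    # Shape sketch for preprocess_images(): a single (H, W, 3) uint8 image
    # becomes a float32 tensor of shape (1, 3, H, W) with values in [-1, 1].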
@staticmethod
def is_valid_file(fname):
        # Compare extensions case-insensitively so files like 'photo.JPG' pass
        ext = fname.split('.')[-1].lower()
        return ext in VALID_FORMATS
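

if __name__ == '__main__':
    # Minimal usage sketch: 'hayao' mirrors the constructor default; the
    # directories below are placeholders, not files shipped with this repo.
    predictor = Predictor(weight='hayao', device='cpu')
    predictor.transform_in_dir('example/inputs', 'example/outputs')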