Spaces:
Running
Running
import torch | |
import cv2 | |
import os | |
import numpy as np | |
import shutil | |
from models.anime_gan import GeneratorV1 | |
from models.anime_gan_v2 import GeneratorV2 | |
from models.anime_gan_v3 import GeneratorV3 | |
from utils.common import load_checkpoint, RELEASED_WEIGHTS | |
from utils.image_processing import resize_image, normalize_input, denormalize_input | |
from utils import read_image, is_image_file | |
from tqdm import tqdm | |
# from torch.cuda.amp import autocast | |
try: | |
import matplotlib.pyplot as plt | |
except ImportError: | |
plt = None | |
try: | |
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer | |
from moviepy.video.io.VideoFileClip import VideoFileClip | |
except ImportError: | |
ffmpeg_writer = None | |
VideoFileClip = None | |
# Lower-case file extensions (without the leading dot) treated as images.
VALID_FORMATS = {'jpeg', 'jpg', 'jpe', 'png', 'bmp'}
def auto_load_weight(weight, version=None, map_location=None):
    """Build the generator matching ``weight`` and load the checkpoint into it.

    The generator class is picked, in order of priority, from:

    1. ``version`` when it is given explicitly;
    2. the ``RELEASED_WEIGHTS`` table, keyed by the lower-cased basename of
       ``weight`` (the first element of the entry is used as the version);
    3. the file-name prefix: ``generatorv2*`` -> V2, ``generatorv3*`` -> V3,
       any other ``generator*`` -> V1.

    Args:
        weight: Name or path of the checkpoint. Only its lower-cased basename
            is used to choose the class; the value itself is passed unchanged
            to ``load_checkpoint``.
        version: Optional ``"v1"``, ``"v2"`` or ``"v3"`` (case-insensitive).
        map_location: Forwarded to ``load_checkpoint``.

    Returns:
        The instantiated generator with the checkpoint loaded, switched to
        eval mode.

    Raises:
        ValueError: If ``version`` is not a known version, or if no version is
            given and none can be inferred from the weight name.
    """
    weight_name = os.path.basename(weight).lower()

    if version is None:
        # For convenience, a weight file should start with the class name,
        # e.g. Generatorv2_{anything}.pt
        if weight_name in RELEASED_WEIGHTS:
            version = RELEASED_WEIGHTS[weight_name][0]
        elif weight_name.startswith("generatorv2"):
            version = "v2"
        elif weight_name.startswith("generatorv3"):
            version = "v3"
        elif weight_name.startswith("generator"):
            version = "v1"
        else:
            raise ValueError((f"Can not get Model from {weight_name}, "
                              "you might need to explicitly specify version"))

    version = version.lower()
    # Explicit raise instead of `assert`: assertions disappear under
    # `python -O`, which would turn a bad version into an opaque KeyError.
    if version not in ("v1", "v2", "v3"):
        raise ValueError(f"Version {version} does not exist")

    generators = {
        "v1": GeneratorV1,
        "v2": GeneratorV2,
        "v3": GeneratorV3,
    }
    model = generators[version]()
    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
    model.eval()
    return model
class Predictor:
    """Apply an AnimeGAN generator to images, image folders and videos.

    Images handed to :meth:`transform` are channel-last arrays. The folder
    and video helpers read with OpenCV (BGR) and flip the channel order
    before and after the model so that the model always sees the same
    channel order as in :meth:`transform_in_dir`.
    """

    def __init__(self, weight='hayao', device='cpu', amp=True):
        """Load the generator and move it to ``device``.

        Args:
            weight: Checkpoint name or path, resolved by ``auto_load_weight``.
            device: Anything accepted by ``torch.device``.
            amp: Accepted for backward compatibility only. Mixed precision is
                currently always disabled, so this value is ignored.
        """
        # AMP and the autocast device type are pinned to CPU-safe values;
        # the CUDA/autocast path is intentionally switched off in this file.
        self.amp = False  # Automatic Mixed Precision
        self.device_type = 'cpu'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)

    def transform_and_show(
        self,
        image_path,
        figsize=(18, 10),
        save_path=None
    ):
        """Show the input image next to its stylised version.

        Args:
            image_path: Path handed to ``read_image``.
            figsize: Matplotlib figure size.
            save_path: When not ``None``, the figure is also saved there.

        Raises:
            ImportError: If matplotlib is not installed.
        """
        if plt is None:
            raise ImportError("matplotlib is required for transform_and_show")

        image = resize_image(read_image(image_path))
        anime_img = self.transform(image).astype('uint8')[0]

        fig = plt.figure(figsize=figsize)
        for position, picture in enumerate((image, anime_img), start=1):
            fig.add_subplot(1, 2, position)
            plt.imshow(picture)
            plt.axis('off')
        plt.tight_layout()

        # Save before showing: once show() returns the figure may already be
        # closed, and saving afterwards can produce an empty image.
        if save_path is not None:
            fig.savefig(save_path)
        plt.show()

    def transform(self, image, denorm=True):
        '''
        Transform an image (or a batch of images) to animation style.

        @Arguments:
            - image: np.array, shape = (height, width, channels) or
              (batch, height, width, channels)
            - denorm: when True the output goes through denormalize_input
              with dtype=np.uint8; otherwise the raw generator output is
              returned.

        @Returns:
            - anime version of image: np.array, channel last, always with a
              leading batch dimension
        '''
        with torch.no_grad():
            batch = self.preprocess_images(image)
            fake = self.G(batch)
            # Channel first (model output) -> channel last
            fake = fake.detach().cpu().numpy().transpose(0, 2, 3, 1)

        if denorm:
            fake = denormalize_input(fake, dtype=np.uint8)
        return fake

    def transform_image(self, image):
        """Resize ``image`` with ``resize_image`` and return its stylised version.

        Returns the single stylised image (no batch dimension).
        """
        return self.transform(resize_image(image))[0]

    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
        '''
        Read all images from img_dir, transform and write the result
        to dest_dir as JPEG files named after the input files.

        @Arguments:
            - img_dir: directory scanned (non-recursively) for image files
            - dest_dir: output directory, created if missing
            - max_images: when non-zero, only the first max_images files
              (in sorted order) are processed
            - img_size: unused; kept so existing callers keep working
        '''
        os.makedirs(dest_dir, exist_ok=True)

        # Sorted so that `max_images` selects a deterministic subset.
        files = sorted(f for f in os.listdir(img_dir) if self.is_valid_file(f))
        print(f'Found {len(files)} images in {img_dir}')

        if max_images:
            files = files[:max_images]

        for fname in tqdm(files):
            image = cv2.imread(os.path.join(img_dir, fname))
            if image is None:
                # cv2.imread signals an unreadable/corrupt file with None.
                print(f'Could not read image {fname}, skipping')
                continue
            # BGR (OpenCV) -> RGB before the model, back to BGR for imwrite.
            image = resize_image(image[:, :, ::-1])
            anime_img = self.transform(image)[0]
            stem = os.path.splitext(fname)[0]
            cv2.imwrite(os.path.join(dest_dir, f'{stem}.jpg'), anime_img[..., ::-1])

    @staticmethod
    def _frame_window(start, end, fps, total_frames):
        """Convert a [start, end] window in seconds to (first_frame, n_frames).

        A falsy ``start`` means "from the beginning" and a falsy ``end``
        means "until the last frame".
        """
        if not (start or end):
            return 0, total_frames
        first = int((start or 0) * fps)
        last = int(end * fps) if end else total_frames
        return first, max(last - first, 0)

    def _stylize_stream(self, capture, frame_count, batch_size, frame_shape,
                        emit, progress=False):
        """Read up to ``frame_count`` frames, stylise them and ``emit`` each one.

        Frames are read with ``capture.read()``, grouped into batches of
        ``batch_size`` and passed through :meth:`transform`. ``emit`` is
        called once per stylised frame with a contiguous uint8 BGR array.
        A trailing partial batch is processed as well.
        """
        batch_size = max(1, int(batch_size))
        height, width = frame_shape
        frames = np.zeros((batch_size, height, width, 3), dtype=np.uint8)

        def flush(count):
            # Only the filled part of the buffer is sent to the model.
            for img in self.transform(frames[:count]):
                img = np.clip(img, 0, 255).astype(np.uint8)
                # RGB -> BGR for the OpenCV consumer.
                emit(np.ascontiguousarray(img[..., ::-1]))

        filled = 0
        steps = tqdm(range(frame_count)) if progress else range(frame_count)
        for _ in steps:
            ok, frame = capture.read()
            if not ok:
                break
            # BGR (OpenCV) -> RGB, same convention as transform_in_dir.
            frames[filled] = frame[..., ::-1]
            filled += 1
            if filled == batch_size:
                flush(filled)
                filled = 0

        if filled:
            # Frames left over when the total is not a multiple of batch_size.
            flush(filled)

    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
        """Stylise a video file and write the result to ``output_path`` (mp4v).

        Args:
            input_path: Source video, opened with ``cv2.VideoCapture``.
            output_path: Destination file for ``cv2.VideoWriter``.
            batch_size: Number of frames sent to the model at once.
            start: Start offset in seconds (0 = beginning).
            end: End offset in seconds (0 = until the last frame).

        Errors raised while processing are printed, not propagated; the
        capture and the writer are always released.
        """
        video_capture = cv2.VideoCapture(input_path)
        video_writer = None
        try:
            frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Kept as float: truncating e.g. 29.97 would skew both the
            # start/end frame indices and the playback speed of the output.
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

            start_frame, frame_count = self._frame_window(start, end, fps, total_frames)
            if start_frame:
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            video_writer = cv2.VideoWriter(
                output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

            print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')

            self._stylize_stream(
                video_capture, frame_count, batch_size,
                (frame_height, frame_width), video_writer.write, progress=True)
        except Exception as e:
            # Deliberately best-effort: report and keep what was written.
            print(e)
        finally:
            video_capture.release()
            if video_writer is not None:
                video_writer.release()

    def transform_video1(self, video, batch_size, start, end):
        """Stylise a video and return its frames as JPEG-encoded bytes.

        Args:
            video: Source handed to ``cv2.VideoCapture``.
            batch_size: Number of frames sent to the model at once.
            start: Start offset in seconds (falsy = beginning).
            end: End offset in seconds (falsy = until the last frame).

        Returns:
            A list with one ``bytes`` object (JPEG) per stylised frame, in
            order. Frames that fail to encode are skipped. Errors raised
            while processing are printed and the frames collected so far are
            returned.
        """
        video_buffer = []

        def collect(img):
            success, encoded_image = cv2.imencode('.jpg', img)
            if success:
                video_buffer.append(encoded_image.tobytes())

        video_capture = cv2.VideoCapture(video)
        try:
            frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
            print(f'Transforming video {frame_count} frames, size: ({frame_width}, {frame_height})')

            start_frame, frame_count = self._frame_window(start, end, fps, frame_count)
            if start_frame:
                video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            self._stylize_stream(
                video_capture, frame_count, batch_size,
                (frame_height, frame_width), collect)
        except Exception as e:
            # Deliberately best-effort: report and return the partial result.
            print(e)
        finally:
            video_capture.release()
        return video_buffer

    def preprocess_images(self, images):
        '''
        Preprocess image for inference

        @Arguments:
            - images: np.ndarray, channel last, with or without a leading
              batch dimension

        @Returns
            - images: torch.tensor on self.device, shape
              (batch, channels, height, width)
        '''
        images = images.astype(np.float32)

        # Normalize to [-1, 1]
        images = normalize_input(images)
        images = torch.from_numpy(images).to(self.device)

        # Add batch dim
        if images.dim() == 3:
            images = images.unsqueeze(0)

        # channel first
        return images.permute(0, 3, 1, 2)

    @staticmethod
    def is_valid_file(fname):
        """Return True when ``fname`` has an extension listed in VALID_FORMATS.

        The comparison is case-insensitive, so ``photo.JPG`` is accepted.
        """
        ext = os.path.splitext(fname)[1]
        return ext[1:].lower() in VALID_FORMATS