import torch
import cv2
import os
import numpy as np
import shutil
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from utils.common import load_checkpoint, RELEASED_WEIGHTS
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image, is_image_file
from tqdm import tqdm

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
    from moviepy.video.io.VideoFileClip import VideoFileClip
except ImportError:
    ffmpeg_writer = None
    VideoFileClip = None

# Image file extensions (lower-case, without the dot) accepted by
# ``Predictor.is_valid_file``.
VALID_FORMATS = {
    'jpeg', 'jpg', 'jpe',
    'png', 'bmp',
}


def auto_load_weight(weight, version=None, map_location=None):
    """Build a generator for ``weight``, load the checkpoint and return it.

    The generator class is chosen, in order of precedence, by:

    1. ``version`` when given (``"v1"``, ``"v2"`` or ``"v3"``, case-insensitive);
    2. the ``RELEASED_WEIGHTS`` registry, looked up by the lower-cased file name;
    3. the lower-cased file-name prefix: ``generatorv2*`` -> V2,
       ``generatorv3*`` -> V3, any other ``generator*`` -> V1
       (e.g. ``GeneratorV2_anything.pt``).

    Args:
        weight: checkpoint path / released-weight name passed to ``load_checkpoint``.
        version: optional explicit generator version.
        map_location: forwarded to ``load_checkpoint``.

    Returns:
        The generator with weights loaded, in eval mode.

    Raises:
        ValueError: if ``version`` is unknown or no class can be inferred
            from the file name.
    """
    generators = {"v1": GeneratorV1, "v2": GeneratorV2, "v3": GeneratorV3}
    weight_name = os.path.basename(weight).lower()

    if version is not None:
        version = version.lower()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if version not in generators:
            raise ValueError(f"Version {version} does not exist")
        cls = generators[version]
    elif weight_name in RELEASED_WEIGHTS:
        # The first field of a released-weight entry is the version to use.
        version = RELEASED_WEIGHTS[weight_name][0]
        return auto_load_weight(weight, version=version, map_location=map_location)
    elif weight_name.startswith("generatorv2"):
        cls = GeneratorV2
    elif weight_name.startswith("generatorv3"):
        cls = GeneratorV3
    elif weight_name.startswith("generator"):
        cls = GeneratorV1
    else:
        raise ValueError((f"Can not get Model from {weight_name}, "
                          "you might need to explicitly specify version"))

    model = cls()
    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
    model.eval()
    return model


class Predictor:
    """Run an AnimeGAN generator on images, image directories and videos.

    ``transform`` takes channel-last images in RGB order; the file/video
    helpers convert from OpenCV's BGR order before calling it.
    """

    def __init__(self, weight='hayao', device='cpu', amp=True):
        # NOTE: automatic mixed precision is currently disabled unconditionally;
        # ``amp`` is accepted only for API compatibility and is ignored.
        self.amp = False
        # Only meaningful for the (disabled) autocast path; kept for compatibility.
        self.device_type = 'cpu'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)

    def transform_and_show(self, image_path, figsize=(18, 10), save_path=None):
        """Stylize the image at ``image_path`` and plot it next to the input.

        Args:
            image_path: path read with ``read_image``.
            figsize: matplotlib figure size.
            save_path: if given, the side-by-side figure is also saved there.

        Raises:
            ImportError: if matplotlib is not installed.
        """
        if plt is None:
            raise ImportError("matplotlib is required for transform_and_show")

        image = resize_image(read_image(image_path))
        anime_img = self.transform(image).astype('uint8')

        fig = plt.figure(figsize=figsize)
        for position, img in ((1, image), (2, anime_img[0])):
            fig.add_subplot(1, 2, position)
            plt.imshow(img)
            plt.axis('off')
        plt.tight_layout()
        # Save *before* show(): many backends close/clear the figure in show(),
        # so a later savefig would write an empty image.
        if save_path is not None:
            fig.savefig(save_path)
        plt.show()

    def transform(self, image, denorm=True):
        """Stylize an image or a batch of images.

        Args:
            image: ``np.ndarray`` of shape ``(batch, height, width, channels)``,
                or a single ``(height, width, channels)`` image (a batch
                dimension is added). Expected in RGB order.
            denorm: when true, map the generator output back to pixel values
                with ``denormalize_input(..., dtype=np.uint8)``; otherwise the
                raw (normalized) generator output is returned.

        Returns:
            ``np.ndarray`` in channel-last layout ``(batch, height, width, channels)``.
        """
        with torch.no_grad():
            image = self.preprocess_images(image)
            fake = self.G(image)
            fake = fake.detach().cpu().numpy()
        # Back to channel-last layout.
        fake = fake.transpose(0, 2, 3, 1)
        if denorm:
            fake = denormalize_input(fake, dtype=np.uint8)
        return fake

    def transform_image(self, image):
        """Resize and stylize a single RGB image; return it without the batch dim."""
        return self.transform(resize_image(image))[0]

    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
        """Stylize every image in ``img_dir`` and write ``<stem>.jpg`` to ``dest_dir``.

        Args:
            img_dir: directory scanned (non-recursively) for files accepted by
                ``is_valid_file``.
            dest_dir: output directory, created if missing.
            max_images: if non-zero, process at most this many files.
            img_size: currently unused; kept for backward compatibility.
        """
        os.makedirs(dest_dir, exist_ok=True)
        files = [f for f in os.listdir(img_dir) if self.is_valid_file(f)]
        print(f'Found {len(files)} images in {img_dir}')
        if max_images:
            files = files[:max_images]

        for fname in tqdm(files):
            bgr = cv2.imread(os.path.join(img_dir, fname))
            if bgr is None:
                # cv2.imread returns None for unreadable/corrupt files.
                print(f'Skipping unreadable image {fname}')
                continue
            # OpenCV loads BGR; the model takes RGB.
            image = resize_image(bgr[:, :, ::-1])
            anime_img = self.transform(image)[0]
            stem = os.path.splitext(fname)[0]
            cv2.imwrite(os.path.join(dest_dir, f'{stem}.jpg'), anime_img[..., ::-1])

    @staticmethod
    def _open_video(path, start, end):
        """Open a video, seek to ``start`` seconds and compute how many frames to read.

        Returns:
            ``(capture, fps, width, height, frame_count)``; ``frame_count`` runs
            up to ``end`` seconds when ``end`` is truthy, else to the last frame.

        Raises:
            ValueError: if OpenCV cannot open ``path``.
        """
        capture = cv2.VideoCapture(path)
        if not capture.isOpened():
            capture.release()
            raise ValueError(f'Could not open video {path}')
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional rate (e.g. 29.97): truncating it to int makes the
        # output play at the wrong speed and drift out of sync.
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        start = start or 0
        if start or end:
            start_frame = int(start * fps)
            end_frame = int(end * fps) if end else frame_count
            capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            frame_count = max(0, end_frame - start_frame)
        return capture, fps, width, height, frame_count

    @staticmethod
    def _iter_frame_batches(capture, steps, batch_size, height, width):
        """Yield batches of up to ``batch_size`` RGB frames read from ``capture``.

        One frame is read per item of ``steps``; reading stops early when the
        capture runs out. The final, possibly shorter, batch is also yielded.
        The yielded array is reused, so consume it before requesting the next.
        """
        batch_size = max(1, int(batch_size))
        batch = np.zeros((batch_size, height, width, 3), dtype=np.uint8)
        filled = 0
        for _ in steps:
            ok, frame = capture.read()
            if not ok:
                break
            batch[filled] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            filled += 1
            if filled == batch_size:
                yield batch
                filled = 0
        if filled:
            # Leftover frames that did not fill a whole batch.
            yield batch[:filled]

    def _stylize_frames(self, frames, size=None):
        """Stylize RGB ``frames``; return a list of BGR ``uint8`` images.

        If ``size`` (``(width, height)``) is given, outputs of a different size
        are resized to it, since ``cv2.VideoWriter`` may silently drop frames
        whose size differs from the one it was opened with.
        """
        outputs = []
        for img in self.transform(frames):
            img = np.clip(img, 0, 255).astype(np.uint8)
            if size is not None and (img.shape[1], img.shape[0]) != tuple(size):
                img = cv2.resize(img, tuple(size))
            outputs.append(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        return outputs

    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
        """Stylize a video file and write it to ``output_path`` (mp4v codec).

        Args:
            input_path: video readable by ``cv2.VideoCapture``.
            output_path: destination file; its directory is created if missing.
            batch_size: number of frames sent to the generator at once.
            start: start time in seconds.
            end: end time in seconds; 0 means "until the end of the video".

        Raises:
            ValueError: if the input cannot be opened or the writer cannot be created.
        """
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        capture, fps, width, height, frame_count = self._open_video(input_path, start, end)
        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not writer.isOpened():
            capture.release()
            raise ValueError(f'Could not open video writer for {output_path}')

        print(f'Transforming video {input_path}, {frame_count} frames, size: ({width}, {height})')
        try:
            batches = self._iter_frame_batches(
                capture, tqdm(range(frame_count)), batch_size, height, width)
            for batch in batches:
                for img in self._stylize_frames(batch, size=(width, height)):
                    writer.write(img)
        except Exception as e:
            # Best effort: report and keep whatever has been written so far.
            print(e)
        finally:
            capture.release()
            writer.release()

    def transform_video1(self, video, batch_size, start, end):
        """Stylize a video and return its frames as a list of JPEG-encoded bytes.

        Args:
            video: source accepted by ``cv2.VideoCapture``.
            batch_size: number of frames sent to the generator at once.
            start: start time in seconds.
            end: end time in seconds; falsy means "until the end of the video".

        Returns:
            List of ``bytes``, one JPEG per successfully encoded frame.

        Raises:
            ValueError: if the video cannot be opened.
        """
        capture, fps, width, height, frame_count = self._open_video(video, start, end)
        print(f'Transforming video {frame_count} frames, size: ({width}, {height})')

        video_buffer = []
        try:
            batches = self._iter_frame_batches(
                capture, range(frame_count), batch_size, height, width)
            for batch in batches:
                for img in self._stylize_frames(batch):
                    success, encoded_image = cv2.imencode('.jpg', img)
                    if success:
                        video_buffer.append(encoded_image.tobytes())
        except Exception as e:
            # Best effort: report and return the frames encoded so far.
            print(e)
        finally:
            capture.release()
        return video_buffer

    def preprocess_images(self, images):
        """Convert channel-last images to a normalized channel-first tensor.

        Args:
            images: ``np.ndarray`` of shape ``(batch, height, width, channels)``
                or ``(height, width, channels)``.

        Returns:
            ``torch.Tensor`` of shape ``(batch, channels, height, width)`` on
            ``self.device``, normalized with ``normalize_input``.
        """
        images = images.astype(np.float32)
        # Normalize to [-1, 1]
        images = normalize_input(images)
        images = torch.from_numpy(images)
        images = images.to(self.device)
        # Add batch dim
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        # channel first
        images = images.permute(0, 3, 1, 2)
        return images

    @staticmethod
    def is_valid_file(fname):
        """Return True if ``fname`` has an extension in ``VALID_FORMATS`` (case-insensitive)."""
        ext = os.path.splitext(fname)[1][1:].lower()
        return ext in VALID_FORMATS