from typing import Tuple
from fastai.core import *
from fastai.vision import *
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from .filters import IFilter, MasterFilter, ColorizerFilter
from .generators import gen_inference_deep, gen_inference_wide
# from tensorboardX import SummaryWriter
from PIL import Image
# import ffmpeg
# import youtube_dl
import gc
import requests
from io import BytesIO
import base64
# from IPython import display as ipythondisplay
# from IPython.display import HTML
# from IPython.display import Image as ipythonimage
import cv2
# # adapted from https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/
# def get_watermarked(pil_image: Image) -> Image:
# try:
# image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
# (h, w) = image.shape[:2]
# image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255])
# pct = 0.05
# full_watermark = cv2.imread(
# './resource_images/watermark.png', cv2.IMREAD_UNCHANGED
# )
# (fwH, fwW) = full_watermark.shape[:2]
# wH = int(pct * h)
# wW = int((pct * h / fwH) * fwW)
#         watermark = cv2.resize(full_watermark, (wW, wH), interpolation=cv2.INTER_AREA)  # cv2.resize expects (width, height)
# overlay = np.zeros((h, w, 4), dtype="uint8")
# (wH, wW) = watermark.shape[:2]
# overlay[h - wH - 10 : h - 10, 10 : 10 + wW] = watermark
# # blend the two images together using transparent overlays
# output = image.copy()
# cv2.addWeighted(overlay, 0.5, output, 1.0, 0, output)
# rgb_image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
# final_image = Image.fromarray(rgb_image)
# return final_image
# except:
# # Don't want this to crash everything, so let's just not watermark the image for now.
# return pil_image
class ModelImageVisualizer:
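    """Runs images through a colorization filter chain and plots/saves the results.

    `filter` is typically a MasterFilter wrapping a ColorizerFilter, as built by
    the get_*_image_colorizer helpers at the bottom of this file.
    """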
    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)
def _clean_mem(self):
torch.cuda.empty_cache()
# gc.collect()
def _open_pil_image(self, path: Path) -> Image:
return PIL.Image.open(path).convert('RGB')
    def _get_image_from_url(self, url: str) -> Image:
        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        response.raise_for_status()  # surface HTTP errors early instead of letting PIL fail on an error page
        img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
        return img
def plot_transformed_image_from_url(
self,
url: str,
path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
watermarked: bool = True,
) -> Path:
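        """Download the image at `url` to `path`, then colorize, plot, and save
        it via plot_transformed_image(); returns the path of the saved result."""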
img = self._get_image_from_url(url)
img.save(path)
return self.plot_transformed_image(
path=path,
results_dir=results_dir,
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
compare=compare,
            post_process=post_process,
watermarked=watermarked,
)
def plot_transformed_image(
self,
path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
watermarked: bool = True,
) -> Path:
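        """Colorize the image at `path`, plot it (side by side with the original
        when `compare=True`), save the result under `results_dir`, and return the
        saved image's path. A higher `render_factor` renders the colorization at
        a higher internal resolution; None falls back to the filter's default."""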
path = Path(path)
if results_dir is None:
results_dir = Path(self.results_dir)
result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
)
orig = self._open_pil_image(path)
if compare:
self._plot_comparison(
figsize, render_factor, display_render_factor, orig, result
)
else:
self._plot_solo(figsize, render_factor, display_render_factor, result)
orig.close()
result_path = self._save_result_image(path, result, results_dir=results_dir)
result.close()
return result_path
def plot_transformed_pil_image(
self,
input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
) -> Image:
result = self.get_transformed_pil_image(
input_image, render_factor, post_process=post_process
)
if compare:
self._plot_comparison(
figsize, render_factor, display_render_factor, input_image, result
)
else:
self._plot_solo(figsize, render_factor, display_render_factor, result)
return result
def _plot_comparison(
self,
        figsize: Tuple[int, int],
render_factor: int,
display_render_factor: bool,
orig: Image,
result: Image,
):
fig, axes = plt.subplots(1, 2, figsize=figsize)
self._plot_image(
orig,
axes=axes[0],
figsize=figsize,
render_factor=render_factor,
display_render_factor=False,
)
self._plot_image(
result,
axes=axes[1],
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
)
def _plot_solo(
self,
        figsize: Tuple[int, int],
render_factor: int,
display_render_factor: bool,
result: Image,
):
fig, axes = plt.subplots(1, 1, figsize=figsize)
self._plot_image(
result,
axes=axes,
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
)
    def _save_result_image(self, source_path: Path, image: Image, results_dir: Path = None) -> Path:
if results_dir is None:
results_dir = Path(self.results_dir)
result_path = results_dir / source_path.name
image.save(result_path)
return result_path
def get_transformed_image(
self, path: Path, render_factor: int = None, post_process: bool = True,
watermarked: bool = True,
) -> Image:
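        """Colorize the image at `path` and return the result as a PIL image,
        without plotting or saving anything."""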
self._clean_mem()
orig_image = self._open_pil_image(path)
filtered_image = self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
)
# if watermarked:
# return get_watermarked(filtered_image)
return filtered_image
def get_transformed_pil_image(
self, input_image: Image, render_factor: int = None, post_process: bool = True,
) -> Image:
self._clean_mem()
filtered_image = self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
)
return filtered_image
def _plot_image(
self,
image: Image,
render_factor: int,
axes: Axes = None,
figsize=(20, 20),
        display_render_factor: bool = False,
):
if axes is None:
_, axes = plt.subplots(figsize=figsize)
axes.imshow(np.asarray(image) / 255)
axes.axis('off')
        if render_factor is not None and display_render_factor:
            # draw on this axes explicitly rather than pyplot's implicit current axes
            axes.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )
    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
columns = min(num_images, max_columns)
rows = num_images // columns
rows = rows if rows * columns == num_images else rows + 1
return rows, columns
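# Example (sketch): colorizing an in-memory PIL image rather than a file on
# disk. The file paths below are placeholders; see get_image_colorizer() at
# the bottom of this file for the canonical way to construct a visualizer.
#
#   colorizer = get_image_colorizer(artistic=True)
#   bw = Image.open('test_images/image.png').convert('RGB')
#   color = colorizer.get_transformed_pil_image(bw, render_factor=35)
#   color.save('result_images/image.png')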
# class VideoColorizer:
# def __init__(self, vis: ModelImageVisualizer):
# self.vis = vis
# workfolder = Path('./video')
# self.source_folder = workfolder / "source"
# self.bwframes_root = workfolder / "bwframes"
# self.audio_root = workfolder / "audio"
# self.colorframes_root = workfolder / "colorframes"
# self.result_folder = workfolder / "result"
# def _purge_images(self, dir):
# for f in os.listdir(dir):
#             if re.search(r'.*?\.jpg', f):
# os.remove(os.path.join(dir, f))
# def _get_fps(self, source_path: Path) -> str:
# probe = ffmpeg.probe(str(source_path))
# stream_data = next(
# (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
# None,
# )
# return stream_data['avg_frame_rate']
# def _download_video_from_url(self, source_url, source_path: Path):
# if source_path.exists():
# source_path.unlink()
# ydl_opts = {
# 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
# 'outtmpl': str(source_path),
# 'retries': 30,
#             'fragment_retries': 30  # the YoutubeDL options dict uses underscores, not the CLI's hyphens
# }
# with youtube_dl.YoutubeDL(ydl_opts) as ydl:
# ydl.download([source_url])
# def _extract_raw_frames(self, source_path: Path):
# bwframes_folder = self.bwframes_root / (source_path.stem)
# bwframe_path_template = str(bwframes_folder / '%5d.jpg')
# bwframes_folder.mkdir(parents=True, exist_ok=True)
# self._purge_images(bwframes_folder)
# ffmpeg.input(str(source_path)).output(
# str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0
# ).run(capture_stdout=True)
# def _colorize_raw_frames(
# self, source_path: Path, render_factor: int = None, post_process: bool = True,
# watermarked: bool = True,
# ):
# colorframes_folder = self.colorframes_root / (source_path.stem)
# colorframes_folder.mkdir(parents=True, exist_ok=True)
# self._purge_images(colorframes_folder)
# bwframes_folder = self.bwframes_root / (source_path.stem)
# for img in progress_bar(os.listdir(str(bwframes_folder))):
# img_path = bwframes_folder / img
# if os.path.isfile(str(img_path)):
# color_image = self.vis.get_transformed_image(
# str(img_path), render_factor=render_factor, post_process=post_process,watermarked=watermarked
# )
# color_image.save(str(colorframes_folder / img))
# def _build_video(self, source_path: Path) -> Path:
# colorized_path = self.result_folder / (
# source_path.name.replace('.mp4', '_no_audio.mp4')
# )
# colorframes_folder = self.colorframes_root / (source_path.stem)
# colorframes_path_template = str(colorframes_folder / '%5d.jpg')
# colorized_path.parent.mkdir(parents=True, exist_ok=True)
# if colorized_path.exists():
# colorized_path.unlink()
# fps = self._get_fps(source_path)
# ffmpeg.input(
# str(colorframes_path_template),
# format='image2',
# vcodec='mjpeg',
# framerate=fps,
# ).output(str(colorized_path), crf=17, vcodec='libx264').run(capture_stdout=True)
# result_path = self.result_folder / source_path.name
# if result_path.exists():
# result_path.unlink()
# # making copy of non-audio version in case adding back audio doesn't apply or fails.
# shutil.copyfile(str(colorized_path), str(result_path))
# # adding back sound here
# audio_file = Path(str(source_path).replace('.mp4', '.aac'))
# if audio_file.exists():
# audio_file.unlink()
# os.system(
# 'ffmpeg -y -i "'
# + str(source_path)
# + '" -vn -acodec copy "'
# + str(audio_file)
# + '"'
# )
#         if audio_file.exists():  # .exists is a method; the bare attribute was always truthy
# os.system(
# 'ffmpeg -y -i "'
# + str(colorized_path)
# + '" -i "'
# + str(audio_file)
# + '" -shortest -c:v copy -c:a aac -b:a 256k "'
# + str(result_path)
# + '"'
# )
# print('Video created here: ' + str(result_path))
# return result_path
# def colorize_from_url(
# self,
# source_url,
# file_name: str,
# render_factor: int = None,
# post_process: bool = True,
# watermarked: bool = True,
# ) -> Path:
# source_path = self.source_folder / file_name
# self._download_video_from_url(source_url, source_path)
# return self._colorize_from_path(
# source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
# )
# def colorize_from_file_name(
# self, file_name: str, render_factor: int = None, watermarked: bool = True, post_process: bool = True,
# ) -> Path:
# source_path = self.source_folder / file_name
# return self._colorize_from_path(
# source_path, render_factor=render_factor, post_process=post_process,watermarked=watermarked
# )
# def _colorize_from_path(
# self, source_path: Path, render_factor: int = None, watermarked: bool = True, post_process: bool = True
# ) -> Path:
# if not source_path.exists():
# raise Exception(
#                 'Video at path specified, ' + str(source_path) + ' could not be found.'
# )
# self._extract_raw_frames(source_path)
# self._colorize_raw_frames(
# source_path, render_factor=render_factor,post_process=post_process,watermarked=watermarked
# )
# return self._build_video(source_path)
# def get_video_colorizer(render_factor: int = 21) -> VideoColorizer:
# return get_stable_video_colorizer(render_factor=render_factor)
# def get_artistic_video_colorizer(
# root_folder: Path = Path('./'),
# weights_name: str = 'ColorizeArtistic_gen',
# results_dir='result_images',
# render_factor: int = 35
# ) -> VideoColorizer:
# learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
# filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
# return VideoColorizer(vis)
# def get_stable_video_colorizer(
# root_folder: Path = Path('./'),
# weights_name: str = 'ColorizeVideo_gen',
# results_dir='result_images',
# render_factor: int = 21
# ) -> VideoColorizer:
# learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
# filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
# vis = ModelImageVisualizer(filtr, results_dir=results_dir)
# return VideoColorizer(vis)
def get_image_colorizer(
root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
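    """Convenience factory: returns the 'artistic' colorizer by default, or the
    'stable' one when artistic=False."""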
if artistic:
return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
else:
return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
def get_stable_image_colorizer(
root_folder: Path = Path('./'),
weights_name: str = 'ColorizeStable_gen',
results_dir='result_images',
render_factor: int = 35
) -> ModelImageVisualizer:
learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
return vis
def get_artistic_image_colorizer(
root_folder: Path = Path('./'),
weights_name: str = 'ColorizeArtistic_gen',
results_dir='result_images',
render_factor: int = 35
) -> ModelImageVisualizer:
learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
vis = ModelImageVisualizer(filtr, results_dir=results_dir)
return vis
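# --- Usage sketch (not part of the original module) ---
# Assumes the pretrained weights (e.g. './models/ColorizeArtistic_gen.pth')
# have been downloaded; the input path below is a placeholder.
#
# if __name__ == '__main__':
#     colorizer = get_image_colorizer(artistic=True)
#     result_path = colorizer.plot_transformed_image(
#         'test_images/image.png', render_factor=35, compare=True
#     )
#     print('Result saved to:', result_path)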