File size: 2,065 Bytes
9bdccea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from PIL import Image
import cv2 as cv
import torch
from RealESRGAN import RealESRGAN
import tempfile
import numpy as np
import tqdm
# Select inference device once at import time: GPU when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def infer_image(img: "Image.Image", size_modifier: int) -> "Image.Image":
    """Upscale a single image with RealESRGAN.

    Args:
        img: Input image in any PIL mode; converted to RGB before inference.
        size_modifier: Upscale factor (e.g. 2, 4, 8); selects the matching
            weight file ``weights/RealESRGAN_x{size_modifier}.pth``.

    Returns:
        The upscaled PIL image.

    Raises:
        ValueError: If ``img`` is None or either dimension is >= 5000 px.
    """
    if img is None:
        raise ValueError("Image not uploaded")
    width, height = img.size
    # Reject very large inputs up front, before loading model weights.
    if width >= 5000 or height >= 5000:
        raise ValueError("The image is too large.")
    model = RealESRGAN(device, scale=size_modifier)
    model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
    result = model.predict(img.convert('RGB'))
    print(f"Image size ({device}): {size_modifier} ... OK")
    return result
def infer_video(video_filepath: str, size_modifier: int) -> str:
    """Upscale every frame of a video with RealESRGAN and write a new MP4.

    Args:
        video_filepath: Path to an input video readable by OpenCV.
        size_modifier: Upscale factor; selects the matching weight file
            ``weights/RealESRGAN_x{size_modifier}.pth``.

    Returns:
        Path to the upscaled temporary ``.mp4`` file (caller owns cleanup,
        since the temp file is created with ``delete=False``).

    Raises:
        ValueError: If OpenCV cannot open ``video_filepath``.
    """
    model = RealESRGAN(device, scale=size_modifier)
    model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)

    cap = cv.VideoCapture(video_filepath)
    if not cap.isOpened():
        # Fail loudly instead of silently producing an empty output file.
        raise ValueError(f"Could not open video: {video_filepath}")

    # Named temp file so the path remains valid after we close the handle;
    # OpenCV writes to the path, not the Python file object.
    tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    vid_output = tmpfile.name
    tmpfile.close()

    vid_writer = cv.VideoWriter(
        vid_output,
        fourcc=cv.VideoWriter.fourcc(*'mp4v'),
        fps=cap.get(cv.CAP_PROP_FPS),
        # Output frames are scaled by the same factor the model applies.
        frameSize=(
            int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier,
            int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier,
        ),
    )

    n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
    try:
        for _ in tqdm.tqdm(range(n_frames)):
            ret, frame = cap.read()
            if not ret:
                break  # CAP_PROP_FRAME_COUNT can overestimate; stop at real EOF
            # OpenCV decodes BGR; the model expects an RGB PIL image.
            rgb = Image.fromarray(cv.cvtColor(frame, cv.COLOR_BGR2RGB))
            upscaled = np.array(model.predict(rgb))
            vid_writer.write(cv.cvtColor(upscaled, cv.COLOR_RGB2BGR))
    finally:
        # Always release codec handles; the original leaked `cap` entirely
        # and leaked the writer whenever a frame raised mid-loop.
        cap.release()
        vid_writer.release()

    print(f"Video file : {video_filepath}")
    return vid_output
|