# FaceCropAnime / app.py
# Carzit's picture
# Upload app.py
# a6bb6f2 verified
import os
from pathlib import Path
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import (non_max_suppression, scale_coords, xyxy2xywh)
from utils.torch_utils import select_device
import gradio as gr
import huggingface_hub
from crop import crop
class FaceCrop:
    """Run a detection model over an image and crop around each detection.

    Typical call order: ``load_model`` once, then per image ``load_dataset``,
    ``set_crop_config`` and ``process``. Every value returned by ``crop`` is
    appended to ``self.results``.
    """

    def __init__(self):
        """Select the compute device and start with an empty result list."""
        self.device = select_device()
        # Half precision is enabled whenever the selected device is not the CPU.
        self.half = self.device.type != 'cpu'
        self.results = []

    def load_dataset(self, source):
        """Remember ``source`` and wrap it in a ``LoadImages`` iterator."""
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully load {source}')

    def load_model(self, model):
        """Load weights from ``model`` onto the device (fp16 if ``self.half``)."""
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()
        print(f'Successfully load model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5):
        """Store the options that ``process`` forwards to ``crop``.

        Args:
            target_size: Forwarded to ``crop`` as ``size``.
            mode: Forwarded to ``crop`` as ``mode``.
            face_ratio: Forwarded to ``crop`` as ``face_ratio``.
            threshold: Forwarded to ``crop`` as ``shreshold``.
        """
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold

    def info(self):
        """Print every non-callable instance attribute as ``name  =  value``."""
        # Sorted to keep the alphabetical order that dir() used to provide.
        for name, value in sorted(vars(self).items()):
            if name.startswith('__') or callable(value):
                continue
            print(name, " = ", value)

    def process(self, conf_threshold=0.6):
        """Detect on the loaded dataset and append one crop per detection.

        ``load_model``, ``load_dataset`` and ``set_crop_config`` must have been
        called first. Results are appended to ``self.results``; nothing is
        returned.

        Args:
            conf_threshold: Only detections whose confidence is strictly
                greater than this value are cropped. Defaults to 0.6.
        """
        # Pure inference: no autograd bookkeeping is needed for the forward pass.
        with torch.no_grad():
            for _path, img, im0s, _vid_cap in self.dataset:
                img = torch.from_numpy(img).to(self.device)
                img = img.half() if self.half else img.float()  # uint8 to fp16/32
                img /= 255.0  # 0 - 255 to 0.0 - 1.0
                if img.ndimension() == 3:
                    img = img.unsqueeze(0)

                # Inference followed by non-maximum suppression.
                pred = self.model(img, augment=False)[0]
                pred = non_max_suppression(pred)

                for det in pred:  # detections per image
                    if det is None or not len(det):
                        continue
                    im0 = im0s
                    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                    # Rescale boxes from img_size to im0 size.
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                    for *xyxy, conf, _cls in det:
                        if conf > conf_threshold:
                            # Box converted by xyxy2xywh, then divided by the gain.
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            # NOTE(review): crops are taken from self.source, not the
                            # per-item path yielded by the dataset — confirm intended.
                            # 'shreshold' is the keyword name used by crop(); keep as-is.
                            self.results.append(
                                crop(
                                    self.source,
                                    (x, y),
                                    mode=self.mode,
                                    size=self.target_size,
                                    box=(w, h),
                                    face_ratio=self.face_ratio,
                                    shreshold=self.threshold,
                                )
                            )
def run(img, mode, width, height, face_ratio, threshold):
    """Gradio callback: crop the uploaded image with the global pipeline.

    Args:
        img: File path of the input image, or ``None`` when nothing was
            uploaded.
        mode: Crop mode, forwarded to ``set_crop_config``.
        width: Target width; cast to ``int`` before use.
        height: Target height; cast to ``int`` before use.
        face_ratio: Forwarded to ``set_crop_config``.
        threshold: Forwarded to ``set_crop_config``.

    Returns:
        The list of crops collected by the pipeline; an empty list when no
        image was supplied.
    """
    # The button can be clicked before any image is uploaded; there is
    # nothing to process in that case.
    if img is None:
        return []

    face_crop_pipeline.results = []
    face_crop_pipeline.load_dataset(img)
    # Slider values may arrive as floats; a pixel size must be integral.
    face_crop_pipeline.set_crop_config(
        mode=mode,
        target_size=(int(width), int(height)),
        face_ratio=face_ratio,
        threshold=threshold,
    )
    face_crop_pipeline.process()
    return face_crop_pipeline.results
if __name__ == '__main__':
    # Fetch the detector weights and build the pipeline used by run().
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolov5x_anime.pt")
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)

    # Sample images offered below the controls.
    examples_data = [
        ["examples/Eda.png"],
        ["examples/Chtholly.png"],
        ["examples/Fairies.png"],
    ]

    with gr.Blocks() as app:
        gr.Markdown("# Face Crop Anime")

        # Row 1: input image next to the gallery of crops.
        with gr.Row():
            input_img = gr.Image(label="Input Image", image_mode="RGB", type='filepath')
            output_img = gr.Gallery(label="Cropped Image")

        # Row 2: crop mode (passed to run() as an index) and target size.
        with gr.Row():
            crop_mode = gr.Dropdown(
                ['Auto', 'No Scale', 'Full Screen', 'Fixed Face Propotion'],
                label="Crop Mode",
                value='Auto',
                type='index',
            )
            tgt_width = gr.Slider(32, 2048, value=512, label="Width")
            tgt_height = gr.Slider(32, 2048, value=512, label="Height")

        # Row 3: mode-dependent tuning parameters.
        with gr.Row():
            face_ratio = gr.Slider(
                1, 5,
                step=0.1,
                value=2,
                label="Face Ratio",
                info="Necessary if choosing \'Auto\' or 'Fixed Face Propotion' Mode",
            )
            threshold = gr.Slider(
                1, 5,
                step=0.1,
                value=1.5,
                label="Threshold",
                info="Necessary if choosing \'Auto\' Mode",
            )

        run_btn = gr.Button(variant="primary")

        with gr.Row():
            examples = gr.Examples(examples=examples_data, inputs=input_img)

        # Wire the button: component order matches run()'s parameter order.
        run_inputs = [input_img, crop_mode, tgt_width, tgt_height, face_ratio, threshold]
        run_btn.click(run, run_inputs, [output_img])

    app.launch()