StKirill's picture
Update app.py
e3b8308
raw
history blame
2.29 kB
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
import PIL
import gradio as gr
from PIL import Image, ImageDraw
import requests
# DETR checkpoint used for object detection. The "no_timm" revision is
# requested so that the timm dependency is not needed.
_DETR_CHECKPOINT = "facebook/detr-resnet-101"
_DETR_REVISION = "no_timm"

# Image pre/post-processor and detection model, both loaded from the same
# checkpoint and revision.
processor = DetrImageProcessor.from_pretrained(
    _DETR_CHECKPOINT,
    revision=_DETR_REVISION,
)
model = DetrForObjectDetection.from_pretrained(
    _DETR_CHECKPOINT,
    revision=_DETR_REVISION,
)
def biggest_obj(res):
    """Return the detection whose bounding box covers the largest area.

    Args:
        res: Mapping with a "boxes" entry (iterable of 4-value boxes, each
            read through ``.tolist()`` as ``x1, y1, x2, y2``) and a "labels"
            entry (indexable; each item is read through ``.item()``).

    Returns:
        A tuple ``(ind, coords, cl)`` where ``ind`` is the index of the
        selected box, ``coords`` is that box as a list of four ints, and
        ``cl`` is the label name looked up in ``model.config.id2label``.

    Raises:
        ValueError: If ``res["boxes"]`` contains no boxes.
    """
    # Convert every box once; int() truncates the float coordinates.
    boxes = [[int(v) for v in bb.tolist()] for bb in res["boxes"]]
    if not boxes:
        raise ValueError("no detections to choose from: res['boxes'] is empty")

    def _area(box):
        x1, y1, x2, y2 = box
        return abs(x2 - x1) * abs(y2 - y1)

    # max() returns the first maximal element, so ties go to the earliest
    # box; a box is always selected, even if every area is zero.
    ind = max(range(len(boxes)), key=lambda i: _area(boxes[i]))
    # Coordinates are taken from the selected index, never from whichever
    # box happened to be visited last.
    coords = boxes[ind]
    cl = model.config.id2label[res["labels"][ind].item()]
    return ind, coords, cl
def create_mask(im_shape:tuple, mask_zone:list):
    """Build a single-channel mask with one filled rectangle.

    Args:
        im_shape: Size passed to ``Image.new`` for the mask image.
        mask_zone: Rectangle coordinates passed to ``ImageDraw.rectangle``.

    Returns:
        A new mode "L" image that is 0 everywhere except inside
        ``mask_zone``, which is filled with 255.
    """
    canvas = Image.new("L", im_shape, 0)
    # Paint the selected zone white on the black canvas.
    ImageDraw.Draw(canvas).rectangle(mask_zone, fill=255)
    return canvas
from diffusers import StableDiffusionInpaintPipeline
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on the GPU. On the CPU the weights are
# cast to float32, because float16 inference is not reliably supported there.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
def predict(image, prompt):
    """Repaint the largest detected object in ``image`` according to ``prompt``.

    The image is converted to RGB and resized to 512x512, DETR detections
    scoring above 0.9 are computed, the largest bounding box becomes a
    rectangular mask, and the inpainting pipeline regenerates that region.

    Args:
        image: PIL image (the Gradio input is declared with ``type='pil'``).
        prompt: Text prompt forwarded to the inpainting pipeline.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: If no image was supplied, or if no detection passes the
            score threshold.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image = image.convert("RGB").resize((512, 512))

    # Detection is inference-only, so skip autograd bookkeeping.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processing step is given the
    # reversed (height, width) pair. Only detections scoring above 0.9 are kept.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    if len(results["boxes"]) == 0:
        raise gr.Error(
            "No object was detected with enough confidence; try another image."
        )

    # Only the coordinates of the biggest bounding box are needed here.
    _, coords, _ = biggest_obj(results)
    mask_image = create_mask(image.size, coords)

    # Create the generator on the device the pipeline runs on (a hard-coded
    # "cuda" fails on CPU-only hosts); the fixed seed is kept from the original.
    generator = torch.Generator(device=device).manual_seed(0)
    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=5,
        generator=generator,
        num_images_per_prompt=1,
    ).images
    return images[0]
# Gradio UI: an image and a text prompt go in, one image comes out.
demo = gr.Interface(
    fn=predict,
    title='Stable Diffusion In-Painting',
    inputs=[gr.Image(type='pil'), gr.Textbox(label='prompt')],
    outputs=[gr.Image()],
)
demo.launch(debug=True)