StKirill commited on
Commit
e3b8308
·
1 Parent(s): d9e5573

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -4
app.py CHANGED
@@ -1,7 +1,77 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ import torch
4
+ import PIL
5
+ import gradio as gr
6
+ from PIL import Image, ImageDraw
7
+ import requests
8
+
9
+ # you can specify the revision tag if you don't want the timm dependency
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
12
+
13
+ def biggest_obj(res):
14
+ max_area = 0
15
+ for i, bb in enumerate(res["boxes"]):
16
+ x1,y1,x2,y2 = list(map(int, bb.tolist()))
17
+ area = (abs(x2-x1)*abs(y1-y2))
18
+ if area > max_area:
19
+ max_area = area
20
+ ind = i
21
+ coords = list(map(int, bb.tolist()))
22
+ cl = model.config.id2label[res["labels"][ind].item()]
23
+ return ind, coords, cl
24
+
25
+
26
+ def create_mask(im_shape:tuple, mask_zone:list):
27
+ mask = Image.new("L", im_shape, 0)
28
+ draw = ImageDraw.Draw(mask)
29
+ draw.rectangle(mask_zone, fill=255)
30
+ return mask
31
+
32
+ from diffusers import StableDiffusionInpaintPipeline
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
35
+ "runwayml/stable-diffusion-inpainting",
36
+ revision="fp16",
37
+ torch_dtype=torch.float16,
38
+ ).to(device)
39
+
40
+ def predict(image, prompt):
41
+ image = image.convert("RGB").resize((512, 512))
42
+ # DETR works
43
+ inputs = processor(images=image, return_tensors="pt")
44
+ outputs = model(**inputs)
45
+ # convert outputs (bounding boxes and class logits) to COCO API
46
+ # let's only keep detections with score > 0.9
47
+ target_sizes = torch.tensor([image.size[::-1]])
48
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
49
+
50
+ # find the biggest bb on the image
51
+ ind, coords, cl = biggest_obj(results)
52
+ # mask image
53
+ mask_image = create_mask(image.size, coords)
54
+
55
+ images = pipe(
56
+ prompt=prompt,
57
+ image=image,
58
+ mask_image=mask_image,
59
+ guidance_scale=5,
60
+ generator=torch.Generator(device="cuda").manual_seed(0),
61
+ num_images_per_prompt=1,
62
+ ).images
63
+
64
 
65
+ return(images[0])
 
66
 
67
+ gr.Interface(
68
+ predict,
69
+ title = 'Stable Diffusion In-Painting',
70
+ inputs=[
71
+ gr.Image(type = 'pil'),
72
+ gr.Textbox(label = 'prompt')
73
+ ],
74
+ outputs = [
75
+ gr.Image()
76
+ ]
77
+ ).launch(debug=True)