import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline # , DiffusionPipeline


device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Image segmentation
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
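# note: the segmentation model is left on the CPU; only the diffusion
# pipeline below is moved to `device`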

def segment_image(image):
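    """Runs Mask2Former panoptic segmentation on a PIL image and returns
    the panoptic prediction dict plus a {segment_id: label} mapping."""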
    inputs = seg_processor(image, return_tensors="pt")
    
    with torch.no_grad():
        seg_outputs = seg_model(**inputs)

    # get prediction dict
    seg_prediction = seg_processor.post_process_panoptic_segmentation(seg_outputs, target_sizes=[image.size[::-1]])[0]

    # get segment labels dict
    segment_labels = {}
    for segment in seg_prediction['segments_info']:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = seg_model.config.id2label[segment_label_id]

        segment_labels.update({segment_id : segment_label})

    return seg_prediction, segment_labels
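
# Helper (a sketch, not in the original file): convert the panoptic
# prediction into a binary PIL mask for a single segment id, in the form
# the inpainting pipeline below expects (white = region to repaint).
import numpy as np
from PIL import Image

def mask_for_segment(seg_prediction, target_id):
    seg_map = seg_prediction['segmentation'].cpu().numpy()
    mask = (seg_map == target_id).astype(np.uint8) * 255  # 255 inside the segment
    return Image.fromarray(mask)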

# Image inpainting pipeline
# get Stable Diffusion model for image inpainting
if device == 'cuda':
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.bfloat16, 
        device_map="auto"
    )
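
# Optional memory saver (an assumption, not part of the original setup):
# attention slicing lowers peak GPU memory at a small cost in speed.
# pipe.enable_attention_slicing()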

# pipe = StableDiffusionInpaintPipeline.from_pretrained( # DiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-inpainting",
#         revision="fp16",
#         torch_dtype=torch.bfloat16, 
#         # device_map="auto" # use for Hugging face spaces
#     )
# pipe.to(device) # use for local environment

def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
    """ Uses Stable Diffusion model to inpaint image 

        Inputs: 

            image - input image (PIL or torch tensor)

            mask - mask for inpainting same size as image (PIL or troch tensor)

            W - size of image

            H - size of mask

            prompt - prompt for inpainting

            seed - random seed 

        Outputs:

            images - output images

        """
    generator = torch.Generator(device=device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask, # ensure mask is same type as image
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images
    
    return images
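

# End-to-end usage sketch (hypothetical: the filename, segment id, and
# prompt below are placeholders, not taken from the original).
if __name__ == "__main__":
    img = Image.open("example.jpg").convert("RGB").resize((512, 512))

    seg_prediction, segment_labels = segment_image(img)
    print(segment_labels)  # e.g. {1: 'car', 2: 'sky', ...}; choose an id to replace

    mask = mask_for_segment(seg_prediction, target_id=1)
    images = inpaint(img, mask, W=512, H=512, prompt="a red sports car", seed=42)
    images[0].save("inpainted_0.png")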