|
import torch
|
|
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
|
from diffusers import StableDiffusionInpaintPipeline
|
|
|
|
|
|
# Compute device shared by the models below: the GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
|
|
|
|
|
# Mask2Former (Swin-B backbone) panoptic checkpoint; the processor and the model
# must come from the same checkpoint so their label spaces agree.
# The model is not moved to `device` here, so it stays wherever from_pretrained
# places it.
_SEG_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

seg_processor = AutoImageProcessor.from_pretrained(_SEG_CHECKPOINT)
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(_SEG_CHECKPOINT)
|
|
|
|
def segment_image(image):
    """Run Mask2Former panoptic segmentation on a single image.

    Args:
        image: the input image. ``image.size[::-1]`` is used as the target
            size for post-processing, so a PIL image (whose ``.size`` is
            ``(width, height)``) is presumably expected — confirm for other
            input types.

    Returns:
        A tuple ``(seg_prediction, segment_labels)``:
        ``seg_prediction`` is the first entry returned by
        ``seg_processor.post_process_panoptic_segmentation``, and
        ``segment_labels`` maps each segment ``'id'`` listed in its
        ``'segments_info'`` to the class name from
        ``seg_model.config.id2label``.
    """
    model_inputs = seg_processor(image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_outputs = seg_model(**model_inputs)

    # Reverse .size so the post-processor receives (height, width).
    height_width = image.size[::-1]
    prediction = seg_processor.post_process_panoptic_segmentation(
        raw_outputs, target_sizes=[height_width]
    )[0]

    id2label = seg_model.config.id2label
    labels_by_segment = {
        info['id']: id2label[info['label_id']]
        for info in prediction['segments_info']
    }
    return prediction, labels_by_segment
|
|
|
|
|
|
|
|
# Stable Diffusion inpainting pipeline, loaded once at import time and moved to
# `device`. The CUDA path uses float16; the CPU path keeps the original bfloat16.
#
# The CPU path previously passed device_map="auto". Recent diffusers releases
# only accept "balanced" (a multi-GPU strategy) for pipeline-level device_map,
# so that call failed on CPU-only machines. Both paths now load normally and
# then call .to(device).
#
# NOTE(review): the "runwayml/stable-diffusion-inpainting" repo has reportedly
# been removed from the Hugging Face Hub. If loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-inpainting" may be needed — confirm.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16 if device == 'cuda' else torch.bfloat16,
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5,
            num_samples=3, *, negative_prompt=None, num_inference_steps=50):
    """Inpaint the masked region of an image with the Stable Diffusion pipeline.

    Args:
        image: input image (PIL image or torch tensor).
        mask: inpainting mask, the same size as ``image`` (PIL image or torch
            tensor). It is forwarded to the pipeline as ``mask_image``.
        W: output width in pixels, forwarded as ``width``.
        H: output height in pixels, forwarded as ``height``. The pipeline
            rejects heights and widths that are not divisible by 8.
        prompt: text prompt describing what to paint into the masked region.
        seed: seed for the ``torch.Generator`` created on ``device``. The same
            seed and inputs reproduce the same images.
        guidance_scale: classifier-free guidance strength passed to the
            pipeline.
        num_samples: number of images to generate, forwarded as
            ``num_images_per_prompt``.
        negative_prompt: optional prompt describing what to avoid. ``None``
            leaves it unset, which is the pipeline's default.
        num_inference_steps: number of denoising steps. The default of 50
            matches the pipeline's own default.

    Returns:
        The pipeline output's ``.images``: the generated images, one per
        sample.
    """
    # Seed a generator on the same device as the pipeline so results are
    # reproducible.
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
    )
    return result.images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|