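"""Gradio demo: text-guided masking + inpainting for product-ad style images.

Flow: CLIPSeg segments the object named in the "Mask" textbox, the image and
the inverted mask are sent to a Stable Diffusion v2 inpainting model hosted
on Replicate, and the original object is composited back over the generated
result before it is displayed.
"""
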
import io
from io import BytesIO

import gradio as gr
import requests
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps
import replicate
import os

# Device for CLIPSeg inference; set to 'cuda' if a GPU is available.
device_name = 'cpu'
device = torch.device(device_name)

# CLIPSeg: text-prompted segmentation, used to build the inpainting mask.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model_clip = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)


# Stable Diffusion v2 inpainting model hosted on Replicate; the client reads
# the API token from the REPLICATE_API_TOKEN environment variable.
os.environ['REPLICATE_API_TOKEN'] = '16ea7157b65a155892e29298b6ddac479a12e819'
model_name = 'cjwbw/stable-diffusion-v2-inpainting'
model = replicate.models.get(model_name)
version = model.versions.get("f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec")


def numpy_to_pil(images):
    """Convert a float array in [0, 1] (HWC or NHWC) to a list of PIL images."""
    # Add a batch dimension if a single image is given.
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")

    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def get_mask(text, image):
    """Segment `text` in `image` with CLIPSeg and return the mask as a PIL image."""
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)

    outputs = model_clip(**inputs)
    # CLIPSeg predicts a low-resolution logit map; sigmoid turns it into [0, 1].
    mask = torch.sigmoid(outputs.logits).detach().cpu().unsqueeze(-1).numpy()

    # Resize the mask back to the input image's resolution.
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil


def image_to_byte_array(image: Image.Image) -> io.BytesIO:
    """Serialize a PIL image into an in-memory PNG buffer (file-like, rewound)."""
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    # Rewind so the buffer is read from the start when it is uploaded.
    img_byte_arr.seek(0)
    return img_byte_arr


def predict(prompt, negative_prompt, image, obj2mask):
    """Mask the object named by `obj2mask`, inpaint around it, and composite the result."""
    mask = get_mask(obj2mask, image)
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask.convert("RGB").resize((512, 512))
    # Invert the CLIPSeg mask (object white -> object black) before sending it to the model.
    mask_image = ImageOps.invert(mask_image)

    # negative_prompt is collected by the UI but not forwarded to this model version.
    inputs = {
        # Input prompt
        'prompt': prompt,

        # Initial image to generate variations of. Supported image size is
        # 512x512
        'image': image_to_byte_array(image),

        # Black and white image to use as mask for inpainting over the image
        # provided. Black pixels are inpainted and white pixels are preserved
        'mask': image_to_byte_array(mask_image),

        # Prompt strength when using init image. 1.0 corresponds to full
        # destruction of information in init image
        'prompt_strength': 0.8,

        # Number of images to output. Higher number of outputs may OOM.
        # Range: 1 to 8
        'num_outputs': 1,

        # Number of denoising steps
        # Range: 1 to 500
        'num_inference_steps': 50,

        # Scale for classifier-free guidance
        # Range: 1 to 20
        'guidance_scale': 7.5,

        # Random seed. Leave blank to randomize the seed
        # 'seed': ...,
    }

    # The Replicate prediction returns a list of output image URLs.
    output = version.predict(**inputs)

    response = requests.get(output[0])
    img_final = Image.open(BytesIO(response.content))

    # Restore the original pixels where the mask is black (the detected object):
    # composite takes img_final where the mask is white and image where it is black.
    mask = mask_image.convert('L')
    img_final = Image.composite(img_final, image, mask)
    return img_final


def inference(prompt, negative_prompt, obj2mask, image_numpy):
    """Gradio callback: convert the uploaded numpy image to PIL and run the pipeline."""
    # gr.Image hands the upload over as a uint8 numpy array.
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    img = predict(prompt, negative_prompt, image, obj2mask)
    return img


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="cinematic, advertisement, sharpe focus, ad, ads")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="text, written")
            mask = gr.Textbox(label="Mask", value="shoe")
            input_img = gr.Image()
            run = gr.Button(value="Generate")
        with gr.Column():
            output_img = gr.Image()

    run.click(
        inference,
        inputs=[prompt, negative_prompt, mask, input_img],
        outputs=output_img,
    )

demo.queue(concurrency_count=1)
demo.launch()
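
# Minimal sketch of calling the pipeline without the UI (hypothetical file names):
#
#   img = Image.open("shoe.png").convert("RGB").resize((512, 512))
#   result = predict("cinematic product shot", "", img, "shoe")
#   result.save("shoe_ad.png")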