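"""Gradio demo: type a prompt and an object name, and the app segments the
object with CLIPSeg, then repaints everything around it with Stable
Diffusion v2 inpainting served through the Replicate API."""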
import io
import os
from io import BytesIO

import gradio as gr
import replicate
import requests
import torch
from PIL import Image, ImageOps
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# CLIPSeg runs locally; switch to 'cuda' if a GPU is available
device_name = 'cpu'
device = torch.device(device_name)
# Load CLIPSeg for text-prompted segmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model_clip = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
# NOTE: the API token should come from the environment (e.g. a Space secret)
# rather than being hardcoded in source
os.environ['REPLICATE_API_TOKEN'] = '16ea7157b65a155892e29298b6ddac479a12e819'
model_name = 'cjwbw/stable-diffusion-v2-inpainting'
model = replicate.models.get(model_name)
version = model.versions.get("f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec")
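# Pinning an explicit version hash keeps results reproducible even if the
# model owner publishes a newer build; the inpainting itself runs on
# Replicate's servers, so no local GPU is needed for that step.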
def numpy_to_pil(images):
    # Mirrors the diffusers helper: expects float arrays in [0, 1]
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def get_mask(text, image):
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model_clip(**inputs)
    # Sigmoid turns the logits into a per-pixel probability map for the prompt
    mask = torch.sigmoid(outputs.logits).cpu().unsqueeze(-1).numpy()
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil
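# Usage sketch (hypothetical file name): get_mask("shoe", Image.open("product.jpg").convert("RGB"))
# returns a grayscale PIL image at the input's size, bright where CLIPSeg
# detects the prompt and dark elsewhere.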
def image_to_byte_array(image: Image.Image) -> io.BytesIO:
    # BytesIO is a file-like buffer stored in memory
    imgByteArr = io.BytesIO()
    # image.save expects a file-like object as an argument
    image.save(imgByteArr, format='PNG')
    # Rewind so the next reader starts at the beginning of the buffer
    imgByteArr.seek(0)
    return imgByteArr
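# The rewound buffer can be passed anywhere a file-like object is accepted,
# which is how the 'image' and 'mask' inputs are sent to Replicate below.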
def predict(prompt, negative_prompt, image, obj2mask):
    # negative_prompt is accepted for interface symmetry but is not forwarded
    # to the Replicate model below
    mask = get_mask(obj2mask, image)
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask.convert("RGB").resize((512, 512))
    # Invert the CLIPSeg mask so the detected object is kept and the
    # background is repainted
    mask_image = ImageOps.invert(mask_image)
# open("/home/tobias/WorkspageBE/replicate/tenis.png", "rb") | |
# io.BufferedReader(image_to_byte_array(image)) | |
    inputs = {
        # Input prompt
        'prompt': prompt,
        # Initial image to generate variations of; the model expects 512x512
        'image': image_to_byte_array(image),
        # Black and white image used as the inpainting mask; the inverted
        # CLIPSeg mask marks the background for repainting while preserving
        # the detected object
        'mask': image_to_byte_array(mask_image),
        # Prompt strength when using an init image. 1.0 corresponds to full
        # destruction of information in the init image
        'prompt_strength': 0.8,
        # Number of images to output. Higher numbers of outputs may OOM.
        # Range: 1 to 8
        'num_outputs': 1,
        # Number of denoising steps
        # Range: 1 to 500
        'num_inference_steps': 50,
        # Scale for classifier-free guidance
        # Range: 1 to 20
        'guidance_scale': 7.5,
        # Random seed. Leave blank to randomize the seed
        # 'seed': ...,
    }
    output = version.predict(**inputs)
    # The API returns a list of output URLs; download the first result
    response = requests.get(output[0])
    img_final = Image.open(BytesIO(response.content))
    # Composite the original object back over the generated background
    mask = mask_image.convert('L')
    img_final = Image.composite(img_final, image, mask)
    return img_final
def inference(prompt, negative_prompt, obj2mask, image_numpy):
    # Gradio delivers the image as a uint8 numpy array, so convert it directly
    # (numpy_to_pil expects floats in [0, 1] and would overflow on uint8 input)
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    img = predict(prompt, negative_prompt, image, obj2mask)
    return img
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="cinematic, advertisement, sharp focus, ad, ads")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="text, written")
            mask = gr.Textbox(label="Mask", value="shoe")
            input_img = gr.Image()
            run = gr.Button(value="Generate")
        with gr.Column():
            output_img = gr.Image()
    run.click(
        inference,
        inputs=[prompt, negative_prompt, mask, input_img],
        outputs=output_img,
    )
demo.queue(concurrency_count=1)
demo.launch()
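# Rough local setup, inferred from the imports above (pin versions that match
# this older replicate/gradio API if needed):
#   pip install gradio replicate requests torch transformers Pillow
#   python app.py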