# NOTE: the lines "Spaces:" / "Runtime error" here were Hugging Face page
# chrome scraped into the source, not Python code; kept only as a comment.
import asyncio
import io
from io import BytesIO
import gradio as gr
import requests
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps, ImageDraw, ImageFont
import PIL
import replicate
import os

# Inference device for the local CLIPSeg model ('cuda' or 'cpu').
device_name = 'cpu'
device = torch.device(device_name)

# CLIPSeg: text-prompted segmentation, used to build the inpainting mask.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model_clip = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# NOTE(security): hard-coded API token committed to source — this credential
# is leaked and should be revoked and supplied via the environment instead.
os.environ['REPLICATE_API_TOKEN'] = '77d68d98d66117c680760543799d97177d6aa722'

# Remote Stable Diffusion inpainting model served by the Replicate API.
model = replicate.models.get("stability-ai/stable-diffusion-inpainting")
version = model.versions.get("c28b92a7ecd66eee4aefcd8a94eb9e7f6c3805d5f06038165407fb5cb355ba67")

# Fixed prompt pairs for the two background variations generated per request.
sf_prompt_1 = "sunflowers, old bridge, mountain, grass"
sf_neg_prompt_1 = "animal"
sf_prompt_2 = "fire, landscape"
sf_neg_prompt_2 = "animal"

# Overlay templates and fonts used to compose the final advertisement images.
template1 = Image.open("templates/template1.png").resize((512, 512))
template2 = Image.open("templates/template2.png").resize((512, 512))
fontMain = ImageFont.truetype(font="fonts/arial.ttf", size=32)
fontSecond = ImageFont.truetype(font="fonts/arial.ttf", size=18)
def numpy_to_pil(images):
    """Convert a float numpy image batch (values in [0, 1]) to PIL images.

    Accepts either a single HWC image or an NHWC batch; returns a list of
    PIL images (mode "L" for single-channel input).
    """
    if images.ndim == 3:
        # Promote a lone HWC image to a one-element batch.
        images = images[None, ...]
    frames = (images * 255).round().astype("uint8")
    if frames.shape[-1] == 1:
        # Single-channel: squeeze the channel axis and build grayscale images.
        return [Image.fromarray(frame.squeeze(), mode="L") for frame in frames]
    return [Image.fromarray(frame) for frame in frames]
def get_mask(text, image):
    """Segment `image` with CLIPSeg using `text` as the prompt.

    Returns a single-channel PIL mask resized back to the input image size,
    with pixel intensity proportional to the match probability.
    """
    model_inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    logits = model_clip(**model_inputs).logits
    # Sigmoid turns raw logits into per-pixel probabilities; move to CPU and
    # add a trailing channel axis so numpy_to_pil sees HWC data.
    probabilities = torch.sigmoid(logits).cpu().detach().unsqueeze(-1).numpy()
    return numpy_to_pil(probabilities)[0].resize(image.size)
def image_to_byte_array(image: Image) -> bytes: | |
imgByteArr = io.BytesIO() | |
image.save(imgByteArr, format='PNG') | |
#imgByteArr = imgByteArr.getvalue() | |
return imgByteArr | |
def add_template(image, template):
    """Paste `template` over `image` at the origin, using the template's own
    transparency as the paste mask, and return the (mutated) image."""
    origin = (0, 0)
    image.paste(template, origin, mask=template)
    return image
async def predict(prompt, negative_prompt, image, mask_img):
    """Run Stable Diffusion inpainting on `image` via the Replicate API.

    Args:
        prompt: text describing what to paint into the inpainted region.
        negative_prompt: text describing what to avoid.
        image: source PIL image (resized to 512x512 here).
        mask_img: CLIPSeg mask; it is inverted so the detected object is
            preserved and the background is regenerated.

    Returns:
        The generated PIL image with the original pixels composited back
        over the preserved (black-mask) region.
    """
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask_img.convert("RGB").resize((512, 512))
    # The API inpaints WHITE pixels, but CLIPSeg highlights the object —
    # invert so the object survives and the background is replaced.
    mask_image = ImageOps.invert(mask_image)
    inputs = {
        # Input prompt
        'prompt': prompt,
        # Things that must not appear in the output
        'negative_prompt': negative_prompt,
        # Initial image to generate variations of; must be 512x512.
        'image': image_to_byte_array(image),
        # Black-and-white mask: white pixels are inpainted, black preserved.
        'mask': image_to_byte_array(mask_image),
        # Number of images to output (1-8); higher counts may OOM.
        'num_outputs': 1,
        # Number of denoising steps (1-500).
        'num_inference_steps': 25,
        # Scale for classifier-free guidance (1-20).
        'guidance_scale': 7.5,
        # Random seed. Leave blank to randomize the seed
        # 'seed': ...,
    }
    output = version.predict(**inputs)
    response = requests.get(output[0])
    response.raise_for_status()  # fail loudly instead of decoding an error page
    img_final = Image.open(BytesIO(response.content))
    mask = mask_image.convert('L')
    # Bug fix: Image.composite returns a NEW image — the original call
    # discarded its result, so the original pixels were never restored.
    img_final = PIL.Image.composite(img_final, image, mask)
    return img_final
async def predicts(sf_prompt_1, sf_neg_prompt_1, sf_prompt_2, sf_neg_prompt_2, image, image_numpy, mask_img, only_test):
    """Generate the two inpainted background variations concurrently.

    When `only_test` is true, the (slow, paid) API calls are skipped and two
    copies of the raw input image are returned so the UI can be previewed.

    Returns:
        A pair of PIL images (variation 1, variation 2).
    """
    if only_test:
        img1 = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
        img2 = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
        return img1, img2
    task1 = asyncio.create_task(predict(sf_prompt_1, sf_neg_prompt_1, image, mask_img))
    # NOTE(review): presumably a crude stagger between Replicate calls —
    # confirm whether this 5 s delay is still required by the API.
    await asyncio.sleep(5)
    task2 = asyncio.create_task(predict(sf_prompt_2, sf_neg_prompt_2, image, mask_img))
    # gather() replaces the manual await/`.result()` pair and propagates an
    # exception from either task instead of leaving one silently pending.
    img1, img2 = await asyncio.gather(task1, task2)
    return img1, img2
def draw_text(img, template_coords, main_text, second_text):
    """Draw the main and secondary ad texts onto `img` in place.

    Args:
        img: PIL image to draw on (mutated; nothing is returned).
        template_coords: dict with 'x1'/'y1' (main-text anchor) and
            'x2'/'y2' (secondary-text anchor) pixel coordinates.
        main_text: headline; literal "\\n" sequences become real newlines.
        second_text: sub-text; same newline handling.
    """
    x1 = template_coords['x1']
    y1 = template_coords['y1']
    x2 = template_coords['x2']
    y2 = template_coords['y2']
    # str.replace is a no-op when the pattern is absent, so the original
    # LBYL `if '\\n' in ...` membership guards were redundant.
    main_text = main_text.replace('\\n', '\n')
    second_text = second_text.replace('\\n', '\n')
    draw = ImageDraw.Draw(img)
    draw.text((x1, y1), main_text, fill=(255, 255, 255), font=fontMain)
    draw.text((x2, y2), second_text, fill=(255, 255, 255), font=fontSecond)
def inference(obj2mask, image_numpy, main_text, second_text, only_test):
    """End-to-end pipeline behind the "Gerar" button.

    Segments the object described by `obj2mask`, inpaints two background
    variations, overlays both templates on each, draws the ad texts, and
    returns the four composed 512x512 images.

    Dead-code fix: the original created and seeded a `torch.Generator` that
    was never passed to anything — the Replicate API manages its own seed —
    so it has been removed.
    """
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    mask_img = get_mask(obj2mask, image)
    img1, img2 = asyncio.run(predicts(sf_prompt_1, sf_neg_prompt_1, sf_prompt_2, sf_neg_prompt_2, image, image_numpy, mask_img, only_test))

    # Each generated background gets both overlay templates.
    img1_1 = add_template(img1.copy(), template1.copy())
    img1_2 = add_template(img1.copy(), template2.copy())
    img2_1 = add_template(img2.copy(), template1.copy())
    img2_2 = add_template(img2.copy(), template2.copy())

    # Text anchors per template; the `/2` suggests coordinates authored for a
    # 1024px canvas halved for 512x512 — TODO confirm against template art.
    template1_coords = {
        'x1': 700 / 2,
        'y1': 630 / 2,
        'x2': 420 / 2,
        'y2': 800 / 2
    }
    template2_coords = {
        'x1': 30 / 2,
        'y1': 30 / 2,
        'x2': 300 / 2,
        'y2': 740 / 2
    }
    draw_text(img1_1, template1_coords, main_text, second_text)
    draw_text(img1_2, template2_coords, main_text, second_text)
    draw_text(img2_1, template1_coords, main_text, second_text)
    draw_text(img2_2, template2_coords, main_text, second_text)
    return [img1_1, img1_2, img2_1, img2_2]
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Ad copy inputs (labels are in Portuguese: main / secondary text).
            txt_1 = gr.Textbox(label="Texto principal da propaganda", value="Promoção\nImperdível", lines=2)
            txt_2 = gr.Textbox(label="Texto secundário da propaganda", value="Até 50% para alguns produtos\nEntre em contato com um dos\nnossos vendedores", lines=3)
            # Text prompt naming the object to keep (fed to CLIPSeg as the mask prompt).
            mask = gr.Textbox(label="Descrição da imagem", value="shoe")
            intput_img = gr.Image()  # NOTE(review): variable name is a typo for "input_img"
            run = gr.Button(value="Gerar")
            # When checked, `inference` skips the Replicate calls (fast preview).
            chk_test = gr.Checkbox(label='Gerar Prévia')
    # Outputs: row 1 = variation 1 with templates 1/2, row 2 = variation 2.
    with gr.Row():
        with gr.Column():
            output_img1_1 = gr.Image()
        with gr.Column():
            output_img1_2 = gr.Image()
    with gr.Row():
        with gr.Column():
            output_img2_1 = gr.Image()
        with gr.Column():
            output_img2_2 = gr.Image()
    run.click(
        inference,
        inputs=[mask, intput_img, txt_1, txt_2, chk_test],
        outputs=[output_img1_1, output_img1_2, output_img2_1, output_img2_2],
    )
# Single-worker queue: the pipeline is heavyweight, so requests are serialized.
demo.queue(concurrency_count=1)
demo.launch()