# Author: tobiaspires -- commit b86ee7f ("Update app.py")
import asyncio
import io
from io import BytesIO
import gradio as gr
import requests
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps, ImageDraw, ImageFont
import PIL
import replicate
import os
# Torch device for the CLIPSeg model: 'cpu' or 'cuda'.
device_name = 'cpu'
device = torch.device(device_name)

# CLIPSeg: text-prompted segmentation used to locate the object to keep.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model_clip = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# SECURITY(review): this Replicate API token is committed to source control and
# must be treated as leaked -- revoke it and provide a fresh one through the
# REPLICATE_API_TOKEN environment variable (e.g. a deployment secret).
# setdefault() lets a token supplied by the environment take precedence; the
# literal is kept only as a fallback so existing deployments keep working.
os.environ.setdefault('REPLICATE_API_TOKEN', '77d68d98d66117c680760543799d97177d6aa722')
model = replicate.models.get("stability-ai/stable-diffusion-inpainting")
version = model.versions.get("c28b92a7ecd66eee4aefcd8a94eb9e7f6c3805d5f06038165407fb5cb355ba67")

# Prompt / negative-prompt pairs for the two generated backgrounds.
sf_prompt_1 = "sunflowers, old bridge, mountain, grass"
sf_neg_prompt_1 = "animal"
sf_prompt_2 = "fire, landscape"
sf_neg_prompt_2 = "animal"

# Overlay templates, scaled to the 512x512 working resolution.
template1 = Image.open("templates/template1.png").resize((512, 512))
template2 = Image.open("templates/template2.png").resize((512, 512))

# Fonts for the headline (main) and the secondary text.
fontMain = ImageFont.truetype(font="fonts/arial.ttf", size=32)
fontSecond = ImageFont.truetype(font="fonts/arial.ttf", size=18)
def numpy_to_pil(images):
    """Convert a float array to a list of PIL images.

    Values are scaled by 255, rounded and cast to uint8, so they are
    presumably in [0, 1]. A 3-D array is treated as a single image and
    wrapped into a batch of one. When the last axis has size 1, each image
    is squeezed and built as a grayscale ("L") image. Otherwise it is passed
    to ``Image.fromarray`` unchanged.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    pixels = (batch * 255).round().astype("uint8")
    single_channel = pixels.shape[-1] == 1
    converted = []
    for frame in pixels:
        if single_channel:
            converted.append(Image.fromarray(frame.squeeze(), mode="L"))
        else:
            converted.append(Image.fromarray(frame))
    return converted
def get_mask(text, image):
    """Segment the region of *image* described by *text* with CLIPSeg.

    Args:
        text: Free-text description of the object to locate (e.g. "shoe").
        image: PIL image to segment.

    Returns:
        A grayscale ("L") PIL image resized to ``image.size``. Brighter pixels
        mean a higher sigmoid score for *text* at that location.
    """
    inputs = processor(
        text=[text], images=[image], padding="max_length", return_tensors="pt"
    ).to(device)
    # Pure inference: don't build an autograd graph (saves time and memory).
    with torch.no_grad():
        outputs = model_clip(**inputs)
    # Sigmoid maps logits into [0, 1]; the trailing singleton axis makes
    # numpy_to_pil produce a single-channel image.
    mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
    mask_pil = numpy_to_pil(mask)[0].resize(image.size)
    return mask_pil
def image_to_byte_array(image: "Image.Image") -> io.BytesIO:
    """Encode *image* as PNG into an in-memory binary stream.

    Args:
        image: A PIL image (any object with a ``save(fp, format=...)`` method).

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to position 0 so it
        can be read or uploaded directly.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    # save() leaves the cursor at the end of the data; rewind so consumers
    # that simply call read() get the PNG bytes instead of nothing.
    buffer.seek(0)
    return buffer
def add_template(image, template):
    """Stamp *template* over the top-left corner of *image* and return *image*.

    The paste happens in place: the returned object is *image* itself, so
    callers that need the original must pass a copy. *template* also serves
    as its own paste mask. Per Pillow's ``Image.paste`` rules, an RGBA
    template's alpha band decides which of its pixels are copied.
    """
    top_left = (0, 0)
    image.paste(template, top_left, mask=template)
    return image
async def predict(prompt, negative_prompt, image, mask_img):
    """Regenerate everything except the masked object via SD inpainting.

    Args:
        prompt: Text describing the background to generate.
        negative_prompt: Things the model should avoid generating.
        image: Source PIL image.
        mask_img: Object mask as produced by ``get_mask`` (bright where the
            object to keep is).

    Returns:
        A 512x512 RGB PIL image. It is the generated picture composited with
        the original, so pixels where the object mask is bright come from
        *image*.
    """
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask_img.convert("RGB").resize((512, 512))
    # The model inpaints white pixels and keeps black ones. Invert so the
    # object turns black (kept) and its surroundings white (regenerated).
    mask_image = ImageOps.invert(mask_image)
    inputs = {
        # Input prompt
        'prompt': prompt,
        # Specify things to not see in the output
        'negative_prompt': negative_prompt,
        # Initial image to generate variations of. Supports 512x512 images.
        'image': image_to_byte_array(image),
        # Black and white image to use as mask for inpainting over the image
        # provided. White pixels are inpainted and black pixels are preserved
        'mask': image_to_byte_array(mask_image),
        # Number of images to output. Higher number of outputs may OOM.
        # Range: 1 to 8
        'num_outputs': 1,
        # Number of denoising steps
        # Range: 1 to 500
        'num_inference_steps': 25,
        # Scale for classifier-free guidance
        # Range: 1 to 20
        'guidance_scale': 7.5,
        # Random seed. Leave blank to randomize the seed
        # 'seed': ...,
    }
    # The Replicate call and the download block. Run them in a worker thread
    # so several predict() tasks can overlap instead of stalling the event
    # loop one after another.
    loop = asyncio.get_running_loop()
    img_final = await loop.run_in_executor(None, _run_inpainting, inputs)
    # Image.composite needs both images to share mode and size.
    img_final = img_final.convert("RGB")
    if img_final.size != image.size:
        img_final = img_final.resize(image.size)
    # Put the original pixels back wherever the inverted mask is dark, so the
    # object stays untouched. Grey mask values blend the two images.
    mask = mask_image.convert('L')
    return Image.composite(img_final, image, mask)


def _run_inpainting(inputs):
    """Run the Replicate inpainting model and download its first output image.

    Blocking (network I/O). ``predict`` calls it from a worker thread.
    """
    output = version.predict(**inputs)
    # Bound the download so a stalled connection cannot hang the request.
    response = requests.get(output[0], timeout=60)
    # Surface HTTP errors instead of a confusing "cannot identify image" error.
    response.raise_for_status()
    result = Image.open(BytesIO(response.content))
    # Decode eagerly here, in the worker thread, instead of lazily later.
    result.load()
    return result
async def predicts(sf_prompt_1, sf_neg_prompt_1, sf_prompt_2, sf_neg_prompt_2, image, image_numpy, mask_img, only_test):
    """Produce the two candidate backgrounds, one per prompt pair.

    In preview mode (*only_test*) no model is called. Both results are the
    uploaded array converted to a 512x512 RGB image. Otherwise ``predict``
    runs once per prompt pair, and the results come back in prompt order.
    """
    if not only_test:
        jobs = [asyncio.create_task(predict(sf_prompt_1, sf_neg_prompt_1, image, mask_img))]
        # NOTE(review): the 5 s pause staggers the second request. The reason
        # is not visible here (rate limiting?), so confirm before removing it.
        await asyncio.sleep(5)
        jobs.append(asyncio.create_task(predict(sf_prompt_2, sf_neg_prompt_2, image, mask_img)))
        first = await jobs[0]
        second = await jobs[1]
        return first, second

    def _preview():
        return Image.fromarray(image_numpy).convert("RGB").resize((512, 512))

    return _preview(), _preview()
def draw_text(img, template_coords, main_text, second_text):
    """Draw the advert's headline and secondary text onto *img*, in place.

    *template_coords* holds the top-left anchor of each text:
    ``x1``/``y1`` for *main_text* (font ``fontMain``) and ``x2``/``y2`` for
    *second_text* (font ``fontSecond``). Both texts are drawn in white.
    A literal backslash followed by ``n`` in either text becomes a real
    line break. Returns None.
    """
    main_xy = (template_coords['x1'], template_coords['y1'])
    second_xy = (template_coords['x2'], template_coords['y2'])

    def _unescape_newlines(text):
        return text.replace('\\n', '\n') if '\\n' in text else text

    main_text = _unescape_newlines(main_text)
    second_text = _unescape_newlines(second_text)
    white = (255, 255, 255)
    canvas = ImageDraw.Draw(img)
    canvas.text(main_xy, main_text, fill=white, font=fontMain)
    canvas.text(second_xy, second_text, fill=white, font=fontSecond)
def inference(obj2mask, image_numpy, main_text, second_text, only_test):
    """Gradio callback: build the four advert variants for the uploaded image.

    Args:
        obj2mask: Text describing the object to keep (fed to CLIPSeg).
        image_numpy: Uploaded image as an array accepted by ``Image.fromarray``.
        main_text: Headline text.
        second_text: Secondary text.
        only_test: Preview mode. Skip segmentation and inpainting and use the
            uploaded picture itself as the background.

    Returns:
        ``[img1+template1, img1+template2, img2+template1, img2+template2]``,
        where img1/img2 are the backgrounds for the first/second prompt pair.
    """
    image = Image.fromarray(image_numpy).convert("RGB").resize((512, 512))
    # Only the inpainting path uses the mask; skip the CLIPSeg pass in preview.
    mask_img = None if only_test else get_mask(obj2mask, image)
    img1, img2 = asyncio.run(predicts(sf_prompt_1, sf_neg_prompt_1, sf_prompt_2, sf_neg_prompt_2, image, image_numpy, mask_img, only_test))

    # Text anchors per template: (x1, y1) for the headline, (x2, y2) for the
    # secondary text, in pixels of the 512x512 canvas.
    template1_coords = {
        'x1': 700/2,
        'y1': 630/2,
        'x2': 420/2,
        'y2': 800/2
    }
    template2_coords = {
        'x1': 30/2,
        'y1': 30/2,
        'x2': 300/2,
        'y2': 740/2
    }
    layouts = ((template1, template1_coords), (template2, template2_coords))

    variants = []
    for background in (img1, img2):
        for template, coords in layouts:
            # paste/draw mutate in place and both backgrounds and templates are
            # reused for several variants, so always work on copies.
            variant = add_template(background.copy(), template.copy())
            draw_text(variant, coords, main_text, second_text)
            variants.append(variant)
    return variants
# Gradio UI: an input column (texts, object description, picture, action
# button, preview toggle) followed by a 2x2 grid of generated variants.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            txt_1 = gr.Textbox(label="Texto principal da propaganda", value="Promoção\nImperdível", lines=2)
            txt_2 = gr.Textbox(label="Texto secundário da propaganda", value="Até 50% para alguns produtos\nEntre em contato com um dos\nnossos vendedores", lines=3)
            # Description of the object to keep; passed to CLIPSeg by inference().
            mask = gr.Textbox(label="Descrição da imagem", value="shoe")
            intput_img = gr.Image()
            run = gr.Button(value="Gerar")
            # Preview mode: inference() skips segmentation and inpainting.
            chk_test = gr.Checkbox(label='Gerar Prévia')
    # Rows = background prompt pair (1, 2); columns = template (1, 2).
    with gr.Row():
        with gr.Column():
            output_img1_1 = gr.Image()
        with gr.Column():
            output_img1_2 = gr.Image()
    with gr.Row():
        with gr.Column():
            output_img2_1 = gr.Image()
        with gr.Column():
            output_img2_2 = gr.Image()
    run.click(
        inference,
        inputs=[mask, intput_img, txt_1, txt_2, chk_test],
        outputs=[output_img1_1, output_img1_2, output_img2_1, output_img2_2],
    )

# Handle one request at a time.
demo.queue(concurrency_count=1)

# Start the server only when run as a script (`python app.py`), so importing
# the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()