Spaces:
Sleeping
Sleeping
import os | |
from io import BytesIO | |
import gradio as gr | |
import grpc | |
from PIL import Image | |
from cachetools import LRUCache | |
import hashlib | |
from protos.inference_pb2 import GuideAndRescaleRequest, GuideAndRescaleResponse | |
from protos.inference_pb2_grpc import GuideAndRescaleServiceStub | |
def get_bytes(img):
    """Serialize a PIL image to JPEG-encoded bytes.

    Args:
        img: A PIL image, or ``None``.

    Returns:
        The JPEG-encoded bytes of *img*, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return img
    # JPEG has no alpha channel or palette, so saving an RGBA/LA/P image
    # (e.g. a transparent PNG upload) raises OSError. Flatten to RGB first.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Decode encoded image bytes (e.g. a service response payload) into a PIL image."""
    return Image.open(BytesIO(image))
def resize(img):
    """Return *img* scaled to exactly 512x512 using Lanczos resampling.

    An image that is already 512x512 is returned as-is (same object).
    The aspect ratio is not preserved.
    """
    target_size = (512, 512)
    if img.size == target_size:
        return img
    return img.resize(target_size, Image.Resampling.LANCZOS)
def edit(image, source_prompt, target_prompt, config, progress=gr.Progress(track_tqdm=True)):
    """Send an image and prompt pair to the Guide-and-Rescale gRPC service.

    Args:
        image: PIL image to edit (the UI's ``gr.Image`` input uses ``type="pil"``).
        source_prompt: Text describing the original image ("Init Prompt").
        target_prompt: Text describing the desired result ("Edit Prompt").
        config: Editing config name, e.g. ``"non-stylisation"`` or ``"stylisation"``.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        The edited image decoded from the service response.

    Raises:
        ValueError: If the image or either prompt is missing.
    """
    if image is None or not source_prompt or not target_prompt:
        raise ValueError("Need to upload an image and enter init and edit prompts")
    image_bytes = get_bytes(image)
    # Respect a SERVER address already configured in the environment; only
    # fall back to the local default when none is set.
    server = os.environ.setdefault('SERVER', "0.0.0.0:50052")
    with grpc.insecure_channel(server) as channel:
        stub = GuideAndRescaleServiceStub(channel)
        response: GuideAndRescaleResponse = stub.swap(
            GuideAndRescaleRequest(
                image=image_bytes,
                source_prompt=source_prompt,
                target_prompt=target_prompt,
                config=config,
                use_cache=True,
            )
        )
    # response.image is plain bytes, so decoding after the channel closes is safe.
    return bytes_to_image(response.image)
def get_demo():
    """Build the Guide-and-Rescale Gradio Blocks UI.

    Layout: an input column (image, init/edit prompts, editing-type radio,
    "Edit image" button) next to an output column (edited image), followed by
    clickable examples and a citation section. Uploaded images are passed
    through ``resize``; the button runs ``edit``.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Guide-and-Rescale")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official Guide-and-Rescale Gradio demo:</span>'
            '<a href="https://github.com/AIRI-Institute/Guide-and-Rescale"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://colab.research.google.com/drive/1noKOOcDBBL_m5_UqU15jBBqiM8piLZ1O?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image = gr.Image(label="Image that you want to edit", type="pil")
                with gr.Row():
                    source_prompt = gr.Textbox(label="Init Prompt", info="Describes the content on the original image.")
                    target_prompt = gr.Textbox(label="Edit Prompt", info="Describes what is expected in the output image.")
                    config = gr.Radio(["non-stylisation", "stylisation"], value='non-stylisation',
                                      label="Type of Editing", info="Selects a config for editing.")
                with gr.Row():
                    btn = gr.Button("Edit image")
            with gr.Column():
                with gr.Row():
                    output = gr.Image(label="Result: edited image")
        gr.Examples(
            examples=[
                ["input/1.png", 'A photo of a tiger', 'A photo of a lion', 'non-stylisation'],
                ["input/zebra.jpeg", 'A photo of a zebra', 'A photo of a white horse', 'non-stylisation'],
                ["input/13.png", 'A photo', 'Anime style face', 'stylisation'],
            ],
            inputs=[image, source_prompt, target_prompt, config],
            outputs=output,
        )
        # Normalize every upload to the 512x512 size via resize().
        image.upload(fn=resize, inputs=[image], outputs=image)
        btn.click(fn=edit, inputs=[image, source_prompt, target_prompt, config], outputs=output)
        gr.Markdown('''To cite the paper by the authors
```
TODO: add cite
```
''')
    return demo
if __name__ == '__main__':
    # Build the UI and serve it on all network interfaces, port 7860.
    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)