File size: 3,378 Bytes
c079997
583b7ad
c079997
 
 
 
 
 
 
ad3aed5
c079997
 
 
 
 
ad3aed5
c079997
 
 
 
 
 
 
 
ad3aed5
c079997
 
 
ad3aed5
c079997
 
 
583b7ad
c079997
 
 
 
 
 
 
 
 
 
 
 
 
ebf25c1
 
 
 
c079997
 
 
 
 
 
 
 
 
 
 
ebf25c1
c079997
ebf25c1
 
 
 
c079997
 
 
 
 
 
ebf25c1
 
 
 
c079997
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf25c1
 
 
 
 
 
 
24c3cc8
 
 
 
 
 
 
 
 
c079997
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
import os
from src.hopter.client import Hopter, Environment
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from pydantic_ai.messages import (
    ToolReturnPart
)
from src.utils import upload_image
load_dotenv()

async def process_edit(image, instruction):
    """Run the image-edit agent on an uploaded image and a text instruction.

    Args:
        image: Value of the "Original Image" component (a file path, because
            the component is declared with ``type="filepath"``), or a falsy
            value when nothing was uploaded.
        instruction: Free-text editing instruction from the textbox.

    Returns:
        The ``edited_image_url`` of the first ``EditImageResult`` returned by
        a tool call during this agent run, or ``None`` if no such result was
        produced.
    """
    # A fresh client and mask service per request; the API key comes from the
    # environment (populated by load_dotenv at import time).
    hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)

    # Only upload when an image was actually provided. Previously the upload
    # ran unconditionally, before the `if image:` check, which made that
    # check dead and passed an empty value to the uploader.
    # NOTE(review): with no image, image_url is None; confirm ImageEditDeps
    # and the agent accept that.
    image_url = upload_image(image) if image else None

    messages = [
        {
            "type": "text",
            "text": instruction
        },
    ]
    if image_url is not None:
        messages.append(
            {"type": "image_url", "image_url": {"url": image_url}}
        )

    deps = ImageEditDeps(
        edit_instruction=instruction,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=mask_service
    )
    result = await image_edit_agent.run(
        messages,
        deps=deps
    )

    # The edited image URL is carried by a tool's return value, so scan the
    # messages produced in this run for the first EditImageResult.
    for message in result.new_messages():
        for part in message.parts:
            if isinstance(part, ToolReturnPart) and isinstance(part.content, EditImageResult):
                return part.content.edited_image_url
    return None

async def use_edited_image(edited_image):
    """Pass the edited image through unchanged.

    The UI routes this coroutine's input from the edited-image component
    and its output into the original-image component, so the edited result
    can be edited again.
    """
    passthrough = edited_image
    return passthrough

def clear_instruction():
    """Produce the reset value for the instruction textbox (an empty string).

    Only the instruction is cleared; the images are left as they are.
    """
    blank = str()
    return blank

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# PicEdit")
    gr.Markdown("""
    Welcome to PicEdit - an AI-powered image editing tool. 
    Simply upload an image and describe the changes you want to make in natural language.
    """)

    with gr.Row():
        # Input image on the left
        input_image = gr.Image(label="Original Image", type="filepath")

        with gr.Column():
            # Output image on the right; read-only, filled by process_edit.
            output_image = gr.Image(label="Edited Image", type="filepath", interactive=False, scale=3)
            # Label previously contained mojibake ("πŸ‘ˆ"), i.e. the UTF-8
            # bytes of the 👈 emoji mis-decoded; restored to the intended emoji.
            use_edited_btn = gr.Button("👈 Use Edited Image 👈")

    # Text input for editing instructions
    instruction = gr.Textbox(
        label="Editing Instructions",
        placeholder="Describe the changes you want to make to the image..."
    )

    # Action buttons: clear the instruction text, or apply the edit.
    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Apply Edit", variant="primary")

    # Run the edit agent on the current image + instruction.
    submit_btn.click(
        fn=process_edit,
        inputs=[input_image, instruction],
        outputs=output_image
    )

    # Copy the edited result back into the input slot for iterative editing.
    use_edited_btn.click(
        fn=use_edited_image,
        inputs=[output_image],
        outputs=[input_image]
    )

    # Bind the clear button's click event to only clear the instruction textbox.
    clear_btn.click(
        fn=clear_instruction,
        inputs=[],
        outputs=[instruction]
    )

    # Clickable sample (image URL, instruction) pairs that populate the inputs.
    examples = gr.Examples(
        examples=[
            ["https://i.ibb.co/qYwhcc6j/c837c212afbf.jpg", "remove the pole"],
            ["https://i.ibb.co/2Mrxztw/image.png", "replace the cat with a dog"],
            ["https://i.ibb.co/9mT4cvnt/resized-78-B40-C09-1037-4-DD3-9-F48-D73637-EE4-E51.png", "ENHANCE!"]
        ],
        inputs=[input_image, instruction]
    )

if __name__ == "__main__":
    demo.launch()