File size: 7,749 Bytes
85330fa
d2fcb60
 
 
 
 
 
 
 
 
 
 
 
2b96aad
d2fcb60
 
2b96aad
d2fcb60
 
 
5dc51ba
 
d2fcb60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf9fa34
d2fcb60
 
cf9fa34
d2fcb60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85330fa
2b96aad
d2fcb60
 
 
2b96aad
d2fcb60
2b96aad
d2fcb60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b96aad
d2fcb60
 
85330fa
d2fcb60
 
5dc51ba
d2fcb60
 
5dc51ba
d2fcb60
 
 
 
 
 
 
 
 
 
85330fa
d2fcb60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
import contextlib
import io
import os
import tempfile

import fal_client
import gradio as gr
import requests
from PIL import Image

# SECURITY FIX: an API key was previously hard-coded on this line and has been
# committed to source control — treat that key as compromised and rotate it.
# The key is now read from the environment instead of being embedded in code.
if "FAL_KEY" not in os.environ:
    raise RuntimeError(
        "FAL_KEY environment variable is not set; export your fal.ai API key before launching."
    )
fal_client.api_key = os.environ["FAL_KEY"]

# Model choices (base models).
# Maps the human-readable dropdown label to its Hugging Face repo id; the
# selected repo is passed to fal as "base_1_5_model_repo" in generate_image.
base_model_paths = {
    "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
    "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "Deliberate": "Yntec/Deliberate",
    "Deliberate V2": "Yntec/Deliberate2",
    "Dreamshaper 8": "Lykon/dreamshaper-8",
    "Epic Realism": "emilianJR/epiCRealism"
}

async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """
    Run the fal-ai/ip-adapter-face-id model against a face image URL.

    Returns the result's "image" dict (contains at least a "url" key) on
    success, or None on any failure.
    """
    try:
        handler = fal_client.submit(
            "fal-ai/ip-adapter-face-id",
            arguments={
                "model_type": model_type,
                "prompt": prompt,
                "face_image_url": image_url,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_samples": num_samples,
                "width": width,
                "height": height,
                "base_1_5_model_repo": base_model_paths[base_model],  # SD 1.5 base selected by the user
                "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",  # SDXL model as default
            },
        )
        # BUG FIX: handler.get() is a blocking call; calling it directly inside
        # an `async def` froze the event loop (and the Gradio UI) for the whole
        # generation. Run it in a worker thread instead.
        result = await asyncio.to_thread(handler.get)

        if "image" in result and "url" in result["image"]:
            return result["image"]  # full image-information dictionary
        return None
    except Exception as e:
        # Deliberate best-effort: log and signal failure with None so the
        # caller can surface a friendly error instead of crashing the UI.
        print(f"Error generating image: {e}")
        return None

def fetch_image_from_url(url: str) -> Image.Image:
    """
    Download the image at *url* and return it as a PIL Image.

    Raises requests.HTTPError on a non-2xx response and requests.Timeout if
    the server does not answer within 30 seconds.
    """
    # BUG FIX: the original call had no timeout (could hang forever) and no
    # status check (an HTML error page would be handed to PIL and fail with a
    # confusing "cannot identify image file" error).
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))

async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """
    Full async pipeline: upload the source image, generate, fetch the result.

    Returns a (PIL image, image-info dict) pair on success, or (None, None)
    on any failure — callers always unpack exactly two values.
    """
    # Upload the image and get a URL the fal service can read.
    image_url = await upload_image_to_server(image)

    if not image_url:
        # BUG FIX: this branch previously returned a bare None, which broke
        # the caller's two-value unpacking with a TypeError.
        return None, None

    # Run the image generation.
    image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)

    if image_info and "url" in image_info:
        # Return both the downloaded image and its metadata dictionary.
        return fetch_image_from_url(image_info["url"]), image_info

    return None, None

async def upload_image_to_server(image: Image.Image) -> str:
    """
    Upload a PIL image to fal's file storage and return its public URL.

    Returns an empty string if the upload fails.
    """
    # fal_client.upload_file_async takes a file path, so the image must be
    # written to disk first. BUG FIX: the original wrote to a fixed
    # "temp_image.png" in the CWD, which raced between concurrent requests
    # and was never deleted — use a unique temp file and always clean it up.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        image.save(tmp, format="PNG")
        tmp.close()  # close before upload so the bytes are flushed (and Windows can reopen it)
        return await fal_client.upload_file_async(tmp.name)
    except Exception as e:
        print(f"Error uploading image: {e}")
        return ""
    finally:
        # Remove the temp file whether or not the upload succeeded.
        with contextlib.suppress(OSError):
            os.remove(tmp.name)

def change_style(style):
    """
    Produce the Gradio component updates for the 'Photorealistic' vs
    'Stylized' generation presets.

    NOTE(review): these three updates are wired to
    [guidance_scale, num_samples, width] at the bottom of the file, yet the
    values (True, 1.3/0.1, 1.0/0.8) look more like checkbox/scale settings —
    confirm the intended output components against the original design.
    """
    is_photorealistic = style == "Photorealistic"
    second_value = 1.3 if is_photorealistic else 0.1
    third_value = 1.0 if is_photorealistic else 0.8
    return (
        gr.update(value=True),
        gr.update(value=second_value),
        gr.update(value=third_value),
    )

def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
    """
    Synchronous Gradio entry point: run the async pipeline and format output.

    Returns (PIL image, metadata string) on success, or
    (None, "Error generating image") on failure.
    """
    # BUG FIX: the original created a new event loop with new_event_loop() /
    # set_event_loop() on every click and never closed it, leaking a loop per
    # request. asyncio.run() creates, runs and *closes* a fresh loop.
    result_image, image_info = asyncio.run(
        process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
    )
    if result_image:
        # Use .get() with defaults: the fal response is not guaranteed to
        # carry every metadata field, and a missing key must not crash the UI.
        metadata = (
            f"File Name: {image_info.get('file_name', 'N/A')}\n"
            f"File Size: {image_info.get('file_size', 'N/A')} bytes\n"
            f"Dimensions: {image_info.get('width', '?')}x{image_info.get('height', '?')} px\n"
            f"Seed: {image_info.get('seed', 'N/A')}"
        )
        return result_image, metadata
    return None, "Error generating image"

# Gradio interface: builds the two-column layout (inputs on the left,
# generated image + metadata on the right) and wires the callbacks.
with gr.Blocks() as demo:
    gr.Markdown("## Image Generation with Fal API and Gradio")
    
    with gr.Row():
        with gr.Column():
            # Source face image (passed to the model as face_image_url after upload)
            image_input = gr.Image(label="Upload Image", type="pil")
            
            # Textbox for prompt
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
            
            # Textbox for negative prompt
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
            
            # Radio buttons for model type (Photorealistic or Stylized);
            # its value is also forwarded as the model_type argument.
            style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            
            # Dropdown for selecting the base model (keys of base_model_paths)
            base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")

            # Seed input (integer; precision=0 forces whole numbers)
            seed_input = gr.Number(label="Seed", value=42, precision=0)

            # Guidance scale slider
            guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)

            # Inference steps slider
            num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)

            # Samples slider
            num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)

            # Image dimensions sliders (64-px increments, 256–1024)
            width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
            height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)

            # Button to trigger image generation
            generate_button = gr.Button("Generate Image")
        
        with gr.Column():
            # Display generated image and metadata
            generated_image = gr.Image(label="Generated Image")
            metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)

    # Style change functionality.
    # NOTE(review): change_style returns (value=True, 1.3/0.1, 1.0/0.8) which
    # looks mismatched with [guidance_scale, num_samples, width] — a True
    # value on a slider and a 1.0 width are suspicious; confirm the intended
    # output components.
    style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])

    # Define the interaction between inputs and output
    generate_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
        outputs=[generated_image, metadata_output]
    )

# Launch the Gradio interface
demo.launch()