# app.py — Gradio front-end for the fal.ai "ip-adapter-face-id" model.
# (Header restored from repository page chrome: "Update app.py", rev d2fcb60.)
import asyncio
import io
import os
import tempfile

import fal_client
import gradio as gr
import requests
from PIL import Image
# Read the Fal API key from the environment — never hard-code secrets in source.
# SECURITY: a previous revision embedded a live API key on this line; that key
# must be considered compromised and rotated in the fal.ai dashboard.
FAL_KEY = os.environ.get("FAL_KEY", "")
if not FAL_KEY:
    raise RuntimeError(
        "FAL_KEY environment variable is not set; "
        "configure it (e.g. as a Space secret) before launching the app."
    )
fal_client.api_key = FAL_KEY
# Model choices (base models).
# Maps the human-readable label shown in the UI dropdown (L129) to its
# Hugging Face repo id; the selected repo is forwarded to fal.ai as the
# `base_1_5_model_repo` argument (L44).
base_model_paths = {
    "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
    "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "Deliberate": "Yntec/Deliberate",
    "Deliberate V2": "Yntec/Deliberate2",
    "Dreamshaper 8": "Lykon/dreamshaper-8",
    "Epic Realism": "emilianJR/epiCRealism",
}
async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """
    Call the fal.ai ip-adapter-face-id endpoint and wait for its result.

    Returns the result's "image" dict (which includes at least a "url" key)
    on success, or None when the request fails or no image URL comes back.
    """
    try:
        request_args = {
            "model_type": model_type,
            "prompt": prompt,
            "face_image_url": image_url,
            "negative_prompt": negative_prompt,
            "seed": seed,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "num_samples": num_samples,
            "width": width,
            "height": height,
            # SD 1.5 base repo chosen by the user in the dropdown.
            "base_1_5_model_repo": base_model_paths[base_model],
            # Fixed SDXL base used for the SDXL model_type path.
            "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",
        }
        handler = fal_client.submit("fal-ai/ip-adapter-face-id", arguments=request_args)
        result = handler.get()  # blocks until the remote job finishes
        image_info = result.get("image")
        if image_info is not None and "url" in image_info:
            return image_info  # full image-information dictionary
        return None
    except Exception as exc:
        # Best-effort: log the failure and signal it to the caller with None.
        print(f"Error generating image: {exc}")
        return None
def fetch_image_from_url(url: str) -> Image.Image:
    """
    Download the image at *url* and return it as a PIL Image.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        requests.RequestException: on connection failure or timeout.
    """
    # A timeout prevents the Gradio worker from hanging forever on a dead host.
    response = requests.get(url, timeout=30)
    # Fail fast with a clear HTTP error instead of letting PIL choke on the
    # bytes of an HTML error page.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """
    Upload the face image, run generation, and fetch the resulting image.

    Returns:
        (PIL.Image, dict): the generated image and its metadata on success.
        (None, None): on any failure, so callers can always unpack two values.
    """
    # Upload the image and get a URL the fal.ai API can fetch.
    image_url = await upload_image_to_server(image)
    if not image_url:
        # BUG FIX: previously returned a bare None here, which crashed callers
        # that unpack two values from this coroutine.
        return None, None
    # Run the image generation.
    image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
    if image_info and "url" in image_info:
        # Return both the downloaded image and its metadata dictionary.
        return fetch_image_from_url(image_info["url"]), image_info
    return None, None
async def upload_image_to_server(image: Image.Image) -> str:
    """
    Upload a PIL image to fal.ai storage and return its public URL.

    Returns "" when the upload fails.
    """
    # fal_client.upload_file_async expects a path on disk, so serialize the
    # image to a *unique* temporary PNG. The original wrote a fixed
    # "temp_image.png" into the CWD (racy under concurrent requests) and
    # never deleted it.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        image.save(tmp, format="PNG")
        tmp.close()  # flush and release the handle before upload (Windows-safe)
        return await fal_client.upload_file_async(tmp.name)
    except Exception as e:
        print(f"Error uploading image: {e}")
        return ""
    finally:
        tmp.close()
        # Always remove the temp file, even on failure.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
def change_style(style):
    """
    Return three component updates for the chosen generation type.

    NOTE(review): these updates are wired (L148) to the guidance-scale,
    sample-count and width sliders, yet the payloads (True, 1.3/0.1,
    1.0/0.8) look like they were meant for different components —
    confirm the intended wiring before relying on this.
    """
    if style == "Photorealistic":
        weight, scale = 1.3, 1.0
    else:
        weight, scale = 0.1, 0.8
    return gr.update(value=True), gr.update(value=weight), gr.update(value=scale)
def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
    """
    Synchronous Gradio click handler that drives the async pipeline.

    Returns:
        (PIL.Image, str): the generated image and a human-readable metadata
        summary on success, or (None, "Error generating image") on failure.
    """
    # asyncio.run creates, runs and *closes* an event loop. The previous
    # version built a fresh loop on every click and never closed it (leak).
    result = asyncio.run(process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height))
    # Tolerate both the (image, info) tuple and a bare-None failure result.
    if not isinstance(result, tuple):
        return None, "Error generating image"
    result_image, image_info = result
    if result_image is None or image_info is None:
        return None, "Error generating image"
    # Use .get() throughout so a missing field degrades to 'N/A' instead of
    # raising KeyError and surfacing a stack trace in the UI.
    metadata = (
        f"File Name: {image_info.get('file_name', 'N/A')}\n"
        f"File Size: {image_info.get('file_size', 'N/A')} bytes\n"
        f"Dimensions: {image_info.get('width', 'N/A')}x{image_info.get('height', 'N/A')} px\n"
        f"Seed: {image_info.get('seed', 'N/A')}"
    )
    return result_image, metadata
# Gradio Interface.
# Top-level UI definition: builds the Blocks layout, wires the style radio
# to slider defaults, and connects the Generate button to the pipeline.
with gr.Blocks() as demo:
    gr.Markdown("## Image Generation with Fal API and Gradio")
    with gr.Row():
        with gr.Column():
            # Face image to condition generation on (passed as face_image_url).
            image_input = gr.Image(label="Upload Image", type="pil")
            # Textbox for the positive prompt.
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
            # Textbox for the negative prompt.
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
            # Radio for model type; its value is forwarded as `model_type`.
            style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            # Dropdown of SD 1.5 base models (keys of base_model_paths).
            base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")
            # Integer seed for reproducible generations.
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            # Guidance scale slider.
            guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
            # Inference steps slider.
            num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)
            # Number of samples to generate per request.
            num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)
            # Output image dimension sliders.
            width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
            height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)
            # Button that triggers image generation.
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            # Right-hand column: generated image plus its metadata summary.
            generated_image = gr.Image(label="Generated Image")
            metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
    # Style change handler.
    # NOTE(review): change_style's return values (True, 1.3/0.1, 1.0/0.8)
    # look mismatched with these outputs (guidance_scale, num_samples,
    # width) — confirm the intended component wiring.
    style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])
    # Wire the Generate button to the synchronous pipeline wrapper.
    generate_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
        outputs=[generated_image, metadata_output]
    )
# Launch the Gradio interface.
demo.launch()