# HuggingFace Spaces header (UI residue from page scrape): this app runs on ZeroGPU.
import gradio as gr
import asyncio
import fal_client
from PIL import Image
import requests
import io
import os

# SECURITY: never hardcode API keys in source (the previous revision committed
# a live Fal key). Read it from the environment instead — set FAL_KEY in the
# deployment config / Space secrets.
fal_client.api_key = os.environ.get("FAL_KEY", "")
# Model choices (base SD 1.5 models): display name -> HuggingFace repo id.
# Passed to the fal.ai job as "base_1_5_model_repo".
base_model_paths = {
    "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
    "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "Deliberate": "Yntec/Deliberate",
    "Deliberate V2": "Yntec/Deliberate2",
    "Dreamshaper 8": "Lykon/dreamshaper-8",
    "Epic Realism": "emilianJR/epiCRealism",
}
async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """Submit a face-ID generation job to fal.ai and return the image info.

    Returns the result's "image" dict (which contains at least a "url" key)
    on success, or None when the request fails or the response carries no
    image.
    """
    try:
        handler = fal_client.submit(
            "fal-ai/ip-adapter-face-id",
            arguments={
                "model_type": model_type,
                "prompt": prompt,
                "face_image_url": image_url,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_samples": num_samples,
                "width": width,
                "height": height,
                # Base SD 1.5 model chosen by the user from base_model_paths.
                "base_1_5_model_repo": base_model_paths[base_model],
                # Fixed default SDXL repo.
                "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",
            },
        )
        # handler.get() is a blocking call; run it in a worker thread so the
        # event loop is not starved while the remote job completes.
        result = await asyncio.to_thread(handler.get)
        if "image" in result and "url" in result["image"]:
            return result["image"]  # Full image info dict (url, size, ...)
        return None
    except Exception as e:
        # Best-effort boundary: log and let the caller show an error in the
        # UI instead of crashing the request handler.
        print(f"Error generating image: {e}")
        return None
def fetch_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* and return it as a PIL Image.

    Raises requests.HTTPError on a non-2xx response and a PIL error if the
    payload is not a decodable image.
    """
    # timeout prevents the UI from hanging forever on a stalled connection;
    # raise_for_status stops us from trying to decode an HTML error page.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """Upload the face image, run generation, and return the final result.

    Returns a (PIL image, image info dict) pair, or (None, None) on any
    failure (upload error or generation error).
    """
    # Upload the face image and obtain a URL the fal.ai job can fetch.
    image_url = await upload_image_to_server(image)
    if not image_url:
        # BUG FIX: this used to `return None` (a single value) while the
        # caller unpacks two values, raising TypeError on upload failure.
        return None, None
    # Run the image generation job.
    image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
    if image_info and "url" in image_info:
        # Return both the downloaded image and its metadata.
        return fetch_image_from_url(image_info["url"]), image_info
    return None, None
async def upload_image_to_server(image: Image.Image) -> str:
    """Upload a PIL image to fal_client storage and return its URL.

    Returns "" on failure. The image is serialized to PNG via a scratch
    file because fal_client's uploader takes a path.
    """
    temp_path = "temp_image.png"
    # Save straight to the scratch file — the previous BytesIO round-trip
    # copied the PNG bytes an extra time for no benefit.
    with open(temp_path, "wb") as f:
        image.save(f, format="PNG")
    try:
        return await fal_client.upload_file_async(temp_path)
    except Exception as e:
        print(f"Error uploading image: {e}")
        return ""
    finally:
        # Clean up the scratch file (it used to be left behind forever).
        # NOTE(review): the fixed name is shared by concurrent requests —
        # consider tempfile.NamedTemporaryFile for uniqueness.
        try:
            os.remove(temp_path)
        except OSError:
            pass
def change_style(style):
    """Return three gr.update objects for the chosen generation type.

    Wired in the UI to [guidance_scale, num_samples, width].
    NOTE(review): the first update sets value=True on a numeric slider,
    which looks suspicious — confirm the intended output components.
    """
    if style == "Photorealistic":
        second_val, third_val = 1.3, 1.0
    else:
        second_val, third_val = 0.1, 0.8
    return (
        gr.update(value=True),
        gr.update(value=second_val),
        gr.update(value=third_val),
    )
def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
    """Synchronous Gradio entry point that runs the async pipeline.

    Returns (PIL image, metadata string) on success, or
    (None, "Error generating image") on failure.
    """
    # asyncio.run creates AND closes a fresh event loop; the previous
    # new_event_loop/run_until_complete pattern leaked one loop per click.
    result_image, image_info = asyncio.run(
        process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
    )
    if result_image is not None:
        # Use .get so a missing metadata key degrades to 'N/A' instead of
        # raising KeyError after an otherwise successful generation.
        metadata = (
            f"File Name: {image_info.get('file_name', 'N/A')}\n"
            f"File Size: {image_info.get('file_size', 'N/A')} bytes\n"
            f"Dimensions: {image_info.get('width', 'N/A')}x{image_info.get('height', 'N/A')} px\n"
            f"Seed: {image_info.get('seed', 'N/A')}"
        )
        return result_image, metadata
    return None, "Error generating image"
# Gradio Interface: input column (image, prompts, model controls) on the left,
# generated image + metadata on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Image Generation with Fal API and Gradio")
    with gr.Row():
        with gr.Column():
            # Face image used as the IP-Adapter identity reference.
            image_input = gr.Image(label="Upload Image", type="pil")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
            # Maps directly to the fal.ai "model_type" argument.
            style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            # Keys of base_model_paths; resolved to a repo id at submit time.
            base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
            num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)
            num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)
            width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
            height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            generated_image = gr.Image(label="Generated Image")
            metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)

    # NOTE(review): change_style's first return is gr.update(value=True),
    # which lands on the guidance_scale slider — the output list below looks
    # misaligned with the function's intent; confirm the target components.
    style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])

    # Wire the generate button to the synchronous pipeline wrapper.
    generate_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
        outputs=[generated_image, metadata_output],
    )

# Launch the Gradio interface
demo.launch()