import asyncio
import io
import os
import tempfile

import fal_client
import gradio as gr
import requests
from PIL import Image

# SECURITY: never hard-code API credentials in source control.  The original
# file embedded a live-looking key here — if it was real, rotate it.  The key
# is now read from the environment; export FAL_KEY before launching the app.
fal_client.api_key = os.environ.get("FAL_KEY", "")

# Human-readable labels -> Hugging Face repo ids for the selectable
# Stable Diffusion 1.5 base models.
base_model_paths = {
    "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
    "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "Deliberate": "Yntec/Deliberate",
    "Deliberate V2": "Yntec/Deliberate2",
    "Dreamshaper 8": "Lykon/dreamshaper-8",
    "Epic Realism": "emilianJR/epiCRealism",
}


async def generate_image(image_url: str, prompt: str, negative_prompt: str,
                         model_type: str, base_model: str, seed: int,
                         guidance_scale: float, num_inference_steps: int,
                         num_samples: int, width: int, height: int):
    """Submit a generation job to the fal-ai/ip-adapter-face-id model.

    Args:
        image_url: Public URL of the uploaded face image.
        prompt / negative_prompt: Text conditioning for the diffusion model.
        model_type: Backend variant expected by the fal endpoint.
        base_model: Key into ``base_model_paths`` (selected in the UI).
        seed, guidance_scale, num_inference_steps, num_samples,
        width, height: Standard diffusion sampling parameters.

    Returns:
        The ``image`` dict from the fal result (contains at least ``url``),
        or ``None`` if the call failed or returned no image.
    """
    try:
        handler = fal_client.submit(
            "fal-ai/ip-adapter-face-id",
            arguments={
                "model_type": model_type,
                "prompt": prompt,
                "face_image_url": image_url,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_samples": num_samples,
                "width": width,
                "height": height,
                # Base model selected by the user in the dropdown.
                "base_1_5_model_repo": base_model_paths[base_model],
                # SDXL repo is fixed; only the 1.5 base is user-selectable.
                "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",
            },
        )
        # NOTE: handler.get() blocks until the job finishes; acceptable here
        # because Gradio runs each request in its own worker.
        result = handler.get()
        if "image" in result and "url" in result["image"]:
            return result["image"]  # full image-info dict (url, size, ...)
        return None
    except Exception as e:
        # Best-effort: the UI shows a generic error when None is returned.
        print(f"Error generating image: {e}")
        return None


def fetch_image_from_url(url: str) -> Image.Image:
    """Download *url* and return its content as a PIL Image.

    Raises:
        requests.HTTPError: if the server responds with an error status
            (previously an error page was fed straight into ``Image.open``,
            producing a confusing decode failure).
    """
    response = requests.get(url)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str,
                         model_type: str, base_model: str, seed: int,
                         guidance_scale: float, num_inference_steps: int,
                         num_samples: int, width: int, height: int):
    """Upload the face image, run generation, and fetch the result.

    Returns:
        ``(PIL.Image, image_info_dict)`` on success, ``(None, None)`` on
        any failure.  (Bug fix: the upload-failure path previously returned
        a bare ``None``, which broke the caller's two-value unpacking.)
    """
    image_url = await upload_image_to_server(image)
    if not image_url:
        return None, None

    image_info = await generate_image(
        image_url, prompt, negative_prompt, model_type, base_model, seed,
        guidance_scale, num_inference_steps, num_samples, width, height,
    )
    if image_info and "url" in image_info:
        # Return both the decoded image and its metadata dict.
        return fetch_image_from_url(image_info["url"]), image_info
    return None, None


async def upload_image_to_server(image: Image.Image) -> str:
    """Upload a PIL image via fal_client and return the hosted URL.

    Returns an empty string on failure.  The image is staged through a
    uniquely-named temporary file (the previous fixed ``temp_image.png``
    in the CWD was race-prone and was never cleaned up).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
            image.save(tmp, format="PNG")
        return await fal_client.upload_file_async(tmp_path)
    except Exception as e:
        print(f"Error uploading image: {e}")
        return ""
    finally:
        # Always remove the staging file, even if the upload failed.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


def change_style(style):
    """Return Gradio updates when the generation type radio changes.

    NOTE(review): the outputs are wired to [guidance_scale, num_samples,
    width]; pushing ``True`` / ``1.3`` / ``1.0`` into those sliders looks
    like a mis-mapped copy-paste from another app — confirm the intended
    targets.  Behavior is preserved unchanged here.
    """
    if style == "Photorealistic":
        return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
    else:
        return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)


def gradio_interface(image, prompt, negative_prompt, model_type, base_model,
                     seed, guidance_scale, num_inference_steps, num_samples,
                     width, height):
    """Synchronous Gradio entry point wrapping the async pipeline.

    Returns:
        ``(PIL.Image, metadata_str)`` on success, or ``(None, error_str)``.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result_image, image_info = loop.run_until_complete(
            process_inputs(image, prompt, negative_prompt, model_type,
                           base_model, seed, guidance_scale,
                           num_inference_steps, num_samples, width, height)
        )
    finally:
        # The loop was created per-call; close it to avoid leaking resources.
        loop.close()

    if result_image is not None:
        # Use .get throughout: the fal response schema is not guaranteed to
        # include every field, and a missing key must not crash a successful
        # generation.
        metadata = (
            f"File Name: {image_info.get('file_name', 'N/A')}\n"
            f"File Size: {image_info.get('file_size', 'N/A')} bytes\n"
            f"Dimensions: {image_info.get('width', '?')}x{image_info.get('height', '?')} px\n"
            f"Seed: {image_info.get('seed', 'N/A')}"
        )
        return result_image, metadata
    return None, "Error generating image"


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Image Generation with Fal API and Gradio")

    with gr.Row():
        with gr.Column():
            # Face image to condition the IP-Adapter on.
            image_input = gr.Image(label="Upload Image", type="pil")
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate",
                lines=2,
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                placeholder="Describe elements to avoid",
                lines=2,
            )
            # Generation type doubles as the fal "model_type" argument.
            style = gr.Radio(
                label="Generation type",
                choices=["Photorealistic", "Stylized"],
                value="Photorealistic",
            )
            base_model = gr.Dropdown(
                label="Base Model",
                choices=list(base_model_paths.keys()),
                value="Realistic Vision V4",
            )
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            guidance_scale = gr.Slider(
                label="Guidance Scale", value=7.5, step=0.1,
                minimum=1, maximum=20,
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps", value=50, step=1,
                minimum=10, maximum=100,
            )
            num_samples = gr.Slider(
                label="Number of Samples", value=4, step=1,
                minimum=1, maximum=10,
            )
            width = gr.Slider(
                label="Width", value=1024, step=64,
                minimum=256, maximum=1024,
            )
            height = gr.Slider(
                label="Height", value=1024, step=64,
                minimum=256, maximum=1024,
            )
            generate_button = gr.Button("Generate Image")

        with gr.Column():
            generated_image = gr.Image(label="Generated Image")
            metadata_output = gr.Textbox(
                label="Image Metadata", interactive=False, lines=6,
            )

    # Adjust sliders when the generation type changes (see NOTE in
    # change_style about the suspicious output mapping).
    style.change(fn=change_style, inputs=style,
                 outputs=[guidance_scale, num_samples, width])

    generate_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input, negative_prompt_input, style,
                base_model, seed_input, guidance_scale, num_inference_steps,
                num_samples, width, height],
        outputs=[generated_image, metadata_output],
    )

# Launch the Gradio interface.
demo.launch()