# Hugging Face Space demo: face-conditioned image generation ("ip-adapter-face-id")
# via the Fal API, with a Gradio front end.
import asyncio
import io
import os
import tempfile

import fal_client
import gradio as gr
import requests
from PIL import Image
# SECURITY: an API key was previously hardcoded on this line and committed to
# source — that key must be considered leaked and revoked. Supply FAL_KEY via
# the environment (e.g. a Space secret) instead.
fal_client.api_key = os.environ.get("FAL_KEY", "")
if not fal_client.api_key:
    # Don't crash at import time; API calls will fail with a clear cause.
    print("WARNING: FAL_KEY environment variable is not set; Fal API calls will fail.")
# Model choices (base models): maps the human-readable name shown in the UI
# dropdown to the Hugging Face repo id passed to the Fal endpoint as
# `base_1_5_model_repo` (all are Stable Diffusion 1.5 derivatives).
base_model_paths = {
"Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
"Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
"Deliberate": "Yntec/Deliberate",
"Deliberate V2": "Yntec/Deliberate2",
"Dreamshaper 8": "Lykon/dreamshaper-8",
"Epic Realism": "emilianJR/epiCRealism"
}
async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """Run the fal-ai/ip-adapter-face-id model and return the image info dict.

    Args:
        image_url: URL of the (already uploaded) face reference image.
        prompt / negative_prompt: generation text prompts.
        model_type: Fal "model_type" argument (e.g. photorealistic vs stylized).
        base_model: key into ``base_model_paths`` selected in the UI.
        seed, guidance_scale, num_inference_steps, num_samples, width, height:
            standard diffusion sampling parameters, forwarded verbatim.

    Returns:
        The ``result["image"]`` dict (contains at least a "url" key) on
        success, or ``None`` on any failure (error is printed, not raised).
    """
    def _submit_and_wait():
        # fal_client.submit() and handler.get() are blocking HTTP calls.
        handler = fal_client.submit(
            "fal-ai/ip-adapter-face-id",
            arguments={
                "model_type": model_type,
                "prompt": prompt,
                "face_image_url": image_url,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_samples": num_samples,
                "width": width,
                "height": height,
                "base_1_5_model_repo": base_model_paths[base_model],  # Base model selected by user
                "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",  # SDXL model as default
            },
        )
        return handler.get()

    try:
        # Run the blocking submit/poll in a worker thread so this coroutine
        # does not stall the event loop for the whole generation time.
        result = await asyncio.to_thread(_submit_and_wait)
        image_info = result.get("image") if isinstance(result, dict) else None
        if image_info and "url" in image_info:
            return image_info  # Full image information dictionary.
        return None
    except Exception as e:
        # Best-effort UI: log and let the caller surface a generic error.
        print(f"Error generating image: {e}")
        return None
def fetch_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* and return it as a PIL Image.

    Raises:
        requests.HTTPError: on a non-2xx response (previously an error page
            would be fed to PIL and fail with a confusing decode error).
        requests.Timeout: if the server does not respond within 30 seconds
            (previously this could hang indefinitely).
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
    """Upload the face image, run generation, and fetch the result.

    Returns:
        ``(pil_image, image_info_dict)`` on success, ``(None, None)`` on any
        failure. Always a 2-tuple: the caller unpacks two values, so the old
        bare ``return None`` on upload failure raised a TypeError.
    """
    # Upload the reference image to get a URL the Fal endpoint can fetch.
    image_url = await upload_image_to_server(image)
    if not image_url:
        return None, None
    # Run the image generation.
    image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
    if image_info and "url" in image_info:
        # Return both the rendered image and its metadata dict.
        return fetch_image_from_url(image_info["url"]), image_info
    return None, None
async def upload_image_to_server(image: Image.Image) -> str:
    """Upload a PIL image to fal storage and return its public URL.

    Returns:
        The uploaded file URL, or ``""`` on failure (error is printed).
    """
    tmp_path = None
    try:
        # fal_client.upload_file_async expects a path on disk. Use a unique
        # temp file: the previous fixed name ("temp_image.png") raced between
        # concurrent requests and was never cleaned up.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
            tmp_path = tmp.name
        return await fal_client.upload_file_async(tmp_path)
    except Exception as e:
        print(f"Error uploading image: {e}")
        return ""
    finally:
        # Always remove the temp file, whether the upload succeeded or not.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
def change_style(style):
    """
    Changes the style for 'Photorealistic' or 'Stylized' generation type.

    Returns three gr.update() objects, wired (at the call site) to
    [guidance_scale, num_samples, width] in that order.

    NOTE(review): the returned values look mismatched with those targets —
    gr.update(value=True) would set guidance_scale (a Slider) to True, and
    1.3 / 0.1 would become num_samples. These values resemble ip-adapter
    scale/strength settings rather than slider values; confirm the intended
    output components before relying on this handler.
    """
    if style == "Photorealistic":
        return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
    else:
        return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
    """Synchronous Gradio click handler bridging to the async pipeline.

    Returns:
        ``(pil_image, metadata_string)`` on success, or
        ``(None, "Error generating image")`` on failure.
    """
    # asyncio.run() creates AND closes a fresh event loop per call; the
    # previous new_event_loop()/set_event_loop() approach leaked one loop
    # per button click and never closed it.
    result_image, image_info = asyncio.run(
        process_inputs(image, prompt, negative_prompt, model_type, base_model,
                       seed, guidance_scale, num_inference_steps, num_samples,
                       width, height)
    )
    if result_image is None or not image_info:
        return None, "Error generating image"
    # The Fal response may omit fields; use .get() so a missing key does not
    # crash the UI after an otherwise successful generation.
    metadata = (
        f"File Name: {image_info.get('file_name', 'N/A')}\n"
        f"File Size: {image_info.get('file_size', 'N/A')} bytes\n"
        f"Dimensions: {image_info.get('width', '?')}x{image_info.get('height', '?')} px\n"
        f"Seed: {image_info.get('seed', 'N/A')}"
    )
    return result_image, metadata
# Gradio Interface: declarative UI layout plus event wiring. Left column holds
# all generation controls; right column shows the result and its metadata.
with gr.Blocks() as demo:
    gr.Markdown("## Image Generation with Fal API and Gradio")
    with gr.Row():
        with gr.Column():
            # Image input: the face reference photo, delivered as a PIL Image.
            image_input = gr.Image(label="Upload Image", type="pil")
            # Textbox for prompt
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
            # Textbox for negative prompt
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
            # Radio buttons for model type (Photorealistic or Stylized).
            # This value is forwarded to the API as "model_type".
            style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            # Dropdown for selecting the base model (keys of base_model_paths)
            base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")
            # Seed input (precision=0 forces an integer)
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            # Guidance scale slider
            guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
            # Inference steps slider
            num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)
            # Samples slider
            num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)
            # Image dimensions sliders (64-px increments, as SD expects)
            width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
            height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)
            # Button to trigger image generation
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            # Display generated image and metadata
            generated_image = gr.Image(label="Generated Image")
            metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
    # Style change functionality.
    # NOTE(review): change_style returns (True, 1.3/0.1, 1.0/0.8) which is
    # mapped onto [guidance_scale, num_samples, width] — setting a slider to
    # a boolean looks unintended; verify the intended output components.
    style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])
    # Define the interaction between inputs and output
    generate_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
        outputs=[generated_image, metadata_output]
    )
# Launch the Gradio interface
demo.launch()