# Spaces: Sleeping
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
from gradio_client import Client, handle_file
from PIL import Image
# Initialize the Hugging Face API clients.
# Both clients are created once, at import time, and reused by the functions
# below: one for the captioning Space and one for the image-generation Space.
captioning_client = Client("fancyfeast/joy-caption-pre-alpha")
generation_client = Client("black-forest-labs/FLUX.1-dev")
# Function to caption an image
def caption_image(image):
    """Return the caption produced by the captioning Space for *image*.

    The image is written to a temporary PNG file because it is handed to the
    client by path (through ``handle_file``). The temporary file is removed
    once the request has finished, whether it succeeded or raised.

    Args:
        image: Object exposing ``save(path)``; the UI passes a PIL image.

    Returns:
        The value returned by the ``/stream_chat`` endpoint.
    """
    # Only the path is needed, so close the handle before anything else
    # writes to or reads from the file by name.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_path = temp_file.name

    try:
        image.save(temp_path)
        return captioning_client.predict(
            input_image=handle_file(temp_path),
            api_name="/stream_chat",
        )
    finally:
        # delete=False means nothing else cleans this file up; without this
        # every call would leave one PNG behind in the temp directory.
        os.remove(temp_path)
# Function to generate an image from a text prompt using Hugging Face API
def generate_image_from_caption(caption):
    """Generate an image for *caption* and return it as a PIL image.

    Calls the ``/infer`` endpoint of the generation client with a random seed
    at 1024x1024. The first element of the result is treated as the image
    location, which may be either a local file path or an HTTP(S) URL.

    Args:
        caption: Text prompt sent to the generation endpoint.

    Returns:
        A ``PIL.Image.Image`` for the generated picture.

    Raises:
        requests.HTTPError: If the image URL answers with an error status.
        requests.Timeout: If the image download does not respond in time.
    """
    result = generation_client.predict(
        prompt=caption,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=1024,
        guidance_scale=3.5,
        num_inference_steps=28,
        api_name="/infer",
    )

    # Check if the response is a URL or a local path
    image_location = result[0]

    if not image_location.startswith("http"):
        # Handle local file path. Image.open is lazy, so the pixel data must
        # be read with load() *before* the file is closed; otherwise the
        # returned image would fail as soon as it is saved or displayed.
        with open(image_location, "rb") as file:
            generated = Image.open(file)
            generated.load()
        return generated

    # Fetch image from URL. A timeout keeps a stalled download from hanging
    # the request forever, and raise_for_status surfaces HTTP errors directly
    # instead of as a confusing "cannot identify image file" failure.
    response = requests.get(image_location, timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
# Main function to handle the upload and generate images and captions in a loop
def process_image(image, iterations):
    """Alternate captioning and image generation for a number of rounds.

    Each round captions the current image, generates a new image from that
    caption, and uses the new image as the input of the next round.

    Args:
        image: Starting image (a PIL image from the UI). May be ``None`` only
            when no round is going to run.
        iterations: Number of rounds; rounded to the nearest integer. Zero or
            a negative value runs no rounds.

    Returns:
        A ``(generated_images, captions, status)`` tuple: the list of images
        produced, the list of captions produced, and a newline-separated
        status log covering every round and ending with
        ``"Processing complete!"``.

    Raises:
        gr.Error: If at least one round is requested but no image was given.
    """
    # Ensure iterations is an integer
    iterations = int(round(iterations))

    # Fail with a readable message instead of an AttributeError raised deep
    # inside the captioning step.
    if iterations > 0 and image is None:
        raise gr.Error("Please upload an image before generating.")

    generated_images = []
    captions = []
    # Collected as a list so that every round's messages are kept and so the
    # status is well-defined even when the loop body never runs.
    status_lines = []
    current_image = image

    for i in range(iterations):
        # Caption the current image
        caption = caption_image(current_image)
        captions.append(caption)
        status_lines.append(f"Caption made: {caption}")

        # Generate a new image based on the caption
        new_image = generate_image_from_caption(caption)
        generated_images.append(new_image)
        status_lines.append(f"Image generated for iteration {i+1}")

        # Set the newly generated image as the current image for the next iteration
        current_image = new_image

    # Notify that the process is completed
    status_lines.append("Processing complete!")
    return generated_images, captions, "\n".join(status_lines)
# Gradio Interface
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed. The click handler must be registered inside the Blocks
# context; which Row the status box and button belong to should be confirmed
# against the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        iterations_input = gr.Number(value=3, label="Number of Iterations", precision=0)

    with gr.Row():
        output_images = gr.Gallery(label="Generated Images")
        output_captions = gr.Textbox(label="Generated Captions")

    status_output = gr.Textbox(label="Status Updates", lines=10)
    generate_button = gr.Button("Generate")

    # process_image returns (images, captions, status), matching the order of
    # the three output components.
    generate_button.click(
        fn=process_image,
        inputs=[image_input, iterations_input],
        outputs=[output_images, output_captions, status_output],
    )

# Launch the app
demo.launch()