import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
import tempfile
import requests
from io import BytesIO
import os

# Initialize the Hugging Face API clients
captioning_client = Client("fancyfeast/joy-caption-pre-alpha")
generation_client = Client("black-forest-labs/FLUX.1-dev")

# Seconds to wait for a generated image to download before giving up.
DOWNLOAD_TIMEOUT = 60


def _detach_image(source):
    """Open ``source`` with PIL and return a fully loaded, independent copy.

    ``Image.open`` is lazy: pixel data is only read when first needed.
    ``copy()`` forces that read while the underlying file/stream is still
    open, so the returned image stays usable after the source is closed.
    """
    with Image.open(source) as img:
        return img.copy()


def caption_image(image):
    """Caption a PIL image with the joy-caption Space.

    The image is written to a temporary PNG because the remote endpoint takes
    a file; the temporary file is removed once the call finishes.

    Args:
        image: PIL image to caption.

    Returns:
        The value returned by the ``/stream_chat`` endpoint, used as caption.
    """
    # Create the file, then close it before writing by name: reopening an
    # open NamedTemporaryFile is not allowed on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_path = temp_file.name
    try:
        image.save(temp_path)
        return captioning_client.predict(
            input_image=handle_file(temp_path),
            api_name="/stream_chat",
        )
    finally:
        # delete=False means nobody else cleans this up; remove it even if
        # saving or the API call fails.
        os.remove(temp_path)


def generate_image_from_caption(caption):
    """Generate a 1024x1024 image from ``caption`` with the FLUX.1-dev Space.

    Args:
        caption: text prompt for the generator.

    Returns:
        A fully loaded PIL image.

    Raises:
        requests.HTTPError: if the result is a URL and the download fails.
        requests.Timeout: if the download exceeds ``DOWNLOAD_TIMEOUT``.
    """
    result = generation_client.predict(
        prompt=caption,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=1024,
        guidance_scale=3.5,
        num_inference_steps=28,
        api_name="/infer",
    )
    # The first element of the result is the image location: either a local
    # file path (already downloaded by gradio_client) or a URL.
    image_location = result[0]
    if not image_location.startswith(("http://", "https://")):
        return _detach_image(image_location)

    response = requests.get(image_location, timeout=DOWNLOAD_TIMEOUT)
    response.raise_for_status()
    return _detach_image(BytesIO(response.content))


def process_image(image, iterations):
    """Alternately caption and regenerate an image, feeding each output back in.

    Args:
        image: the uploaded PIL image (``None`` when nothing was uploaded).
        iterations: number of caption -> generate rounds; rounded to an int.

    Returns:
        A tuple ``(generated_images, captions_text, status)``: the generated
        PIL images in order, all captions separated by blank lines, and a
        log of every step taken.

    Raises:
        gr.Error: shown in the UI when the image or iteration count is missing.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if iterations is None:
        raise gr.Error("Please enter the number of iterations.")
    # Ensure iterations is an integer
    iterations = int(round(iterations))

    generated_images = []
    captions = []
    status_lines = []
    current_image = image

    for i in range(iterations):
        # Caption the current image
        caption = caption_image(current_image)
        captions.append(caption)
        status_lines.append(f"Caption made: {caption}")

        # Generate a new image based on the caption
        new_image = generate_image_from_caption(caption)
        generated_images.append(new_image)
        status_lines.append(f"Image generated for iteration {i+1}")

        # The generated image becomes the input for the next round.
        current_image = new_image

    status_lines.append("Processing complete!")
    captions_text = "\n\n".join(str(caption) for caption in captions)
    return generated_images, captions_text, "\n".join(status_lines)


# Gradio Interface
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        iterations_input = gr.Number(value=3, label="Number of Iterations", precision=0)

    with gr.Row():
        output_images = gr.Gallery(label="Generated Images")
        output_captions = gr.Textbox(label="Generated Captions")

    status_output = gr.Textbox(label="Status Updates", lines=10)

    generate_button = gr.Button("Generate")

    generate_button.click(
        fn=process_image,
        inputs=[image_input, iterations_input],
        outputs=[output_images, output_captions, status_output],
    )

# Launch the app
demo.launch()