File size: 3,143 Bytes
053bf25
cfb8c8a
053bf25
12c9a39
cfb8c8a
 
cfb5c17
053bf25
 
 
 
 
 
 
12c9a39
 
 
cfb8c8a
12c9a39
 
053bf25
 
 
 
 
 
 
 
 
 
 
 
 
 
cfb5c17
 
 
 
 
 
 
 
 
cfb8c8a
 
053bf25
 
 
cfb5c17
 
 
053bf25
 
 
 
 
 
 
 
 
 
cfb5c17
 
 
053bf25
 
 
 
cfb5c17
 
 
053bf25
cfb8c8a
053bf25
cfb5c17
 
 
 
053bf25
 
 
 
 
cfb5c17
053bf25
 
 
 
cfb5c17
053bf25
 
 
 
 
 
cfb5c17
053bf25
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
import tempfile
import requests
from io import BytesIO
import os

# Remote Hugging Face Spaces used by this app: one captions images, the other
# generates images from text. Both clients are created once, at import time,
# captioner first, and are shared by every request handled below.
captioning_client, generation_client = (
    Client("fancyfeast/joy-caption-pre-alpha"),
    Client("black-forest-labs/FLUX.1-dev"),
)

# Function to caption an image
def caption_image(image):
    """Caption *image* using the remote joy-caption Space.

    The client uploads files by path (``handle_file``), so the image is first
    written to a PNG inside a private temporary directory. The directory (and
    the PNG) is removed when the call finishes, whether or not it succeeds.
    Previously a ``delete=False`` temp file was leaked on every call.

    Args:
        image: A PIL image.

    Returns:
        Whatever the Space's ``/stream_chat`` endpoint returns; the app uses
        it as the caption text.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Write to a closed path rather than an open NamedTemporaryFile handle,
        # which cannot be reopened by name on Windows.
        temp_path = os.path.join(tmp_dir, "input.png")
        image.save(temp_path)
        # predict() must run while the file still exists: the upload happens here.
        return captioning_client.predict(
            input_image=handle_file(temp_path),
            api_name="/stream_chat",
        )

# Function to generate an image from a text prompt using Hugging Face API
def generate_image_from_caption(caption):
    """Generate an image from *caption* using the remote FLUX.1-dev Space.

    Args:
        caption: Text prompt sent to the Space's ``/infer`` endpoint.

    Returns:
        A PIL image whose pixel data does not depend on any file that has
        already been closed.

    Raises:
        requests.HTTPError: if the result is a URL and downloading it fails.
        requests.Timeout: if that download does not respond in time.
    """
    result = generation_client.predict(
        prompt=caption,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=1024,
        guidance_scale=3.5,
        num_inference_steps=28,
        api_name="/infer",
    )

    # The first element of the result locates the image: either an http(s)
    # URL or a local file path.
    image_location = result[0]

    if image_location.startswith(("http://", "https://")):
        # Bound the wait and refuse to hand an error page to PIL.
        response = requests.get(image_location, timeout=60)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    # Image.open() is lazy: it reads pixels only on first use. Copy while the
    # file is still open. Returning the lazily-opened image after its file was
    # closed made every later pixel access fail.
    with Image.open(image_location) as img:
        return img.copy()

# Main function to handle the upload and generate images and captions in a loop
def process_image(image, iterations):
    """Repeatedly caption an image and regenerate it from its own caption.

    Each iteration captions the current image, generates a new image from that
    caption, and feeds the new image into the next iteration.

    Args:
        image: The uploaded PIL image (``None`` if nothing was uploaded).
        iterations: Number of caption→generate rounds; rounded to an int.
            Values below 1 run no rounds.

    Returns:
        ``(generated_images, captions, status)``: the list of generated
        images, the list of captions (one per round), and a newline-joined
        log of every round's progress messages.

    Raises:
        gr.Error: if no image was uploaded or the iteration count is empty.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if iterations is None:
        raise gr.Error("Please enter the number of iterations.")

    # Ensure iterations is an integer
    iterations = int(round(iterations))

    generated_images = []
    captions = []
    # Collect messages from every round. Previously `status` was reassigned
    # each iteration (dropping earlier rounds) and was unbound when
    # iterations < 1.
    status_lines = []

    current_image = image
    for i in range(iterations):
        caption = caption_image(current_image)
        captions.append(caption)
        status_lines.append(f"Caption made: {caption}")

        # The newly generated image becomes the input for the next round.
        current_image = generate_image_from_caption(caption)
        generated_images.append(current_image)
        status_lines.append(f"Image generated for iteration {i+1}")

    status_lines.append("Processing complete!")
    return generated_images, captions, "\n".join(status_lines)

# Gradio Interface: inputs (image + iteration count) in the first row,
# results (gallery, captions, status log) in the second, then the button.
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(label="Upload an Image", type="pil")
        iterations_input = gr.Number(
            label="Number of Iterations", value=3, precision=0
        )

    with gr.Row():
        output_images = gr.Gallery(label="Generated Images")
        output_captions = gr.Textbox(label="Generated Captions")
        status_output = gr.Textbox(label="Status Updates", lines=10)

    generate_button = gr.Button("Generate")

    # process_image returns (images, captions, status) in this output order.
    generate_button.click(
        process_image,
        inputs=[image_input, iterations_input],
        outputs=[output_images, output_captions, status_output],
    )

# Start the web server for the app.
demo.launch()