File size: 4,526 Bytes
38d647f
1eb15ec
7d9bafe
cb15469
d94c56e
38041e2
97d4592
 
cb15469
 
 
97d4592
 
 
 
 
 
 
 
 
 
 
 
38041e2
d94c56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb15ec
cb15469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb15ec
b045883
 
cb15469
97d4592
cb15469
97d4592
90c9d43
d94c56e
 
 
066f283
 
b045883
d94c56e
cb15469
 
b045883
90c9d43
 
b33ec08
 
 
90c9d43
 
b045883
cb15469
 
 
 
 
 
 
 
 
 
 
9dff307
 
38041e2
d94c56e
 
b045883
d94c56e
 
 
 
90c9d43
d94c56e
 
 
90c9d43
 
b045883
90c9d43
 
9dff307
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import numpy as np
from io import BytesIO
from PIL import Image, ImageOps
import zipfile
import os
import atexit
import shutil
import cv2
import imageio
import torchvision.transforms.functional as TF

# Directory where the generated ZIP and GIF files are written. It lives for the
# duration of the process; cleanup_generated_files removes it at exit.
GENERATED_FILES_DIR = "generated_files"
# exist_ok=True replaces the separate os.path.exists() check, which was a
# redundant lookup and racy (the directory could appear between check and create).
os.makedirs(GENERATED_FILES_DIR, exist_ok=True)

@atexit.register  # atexit.register returns the function, so it works as a decorator
def cleanup_generated_files():
    """Delete GENERATED_FILES_DIR and all of its contents, if it exists.

    Registered with ``atexit`` so that generated ZIP/GIF files are removed
    when the interpreter shuts down.
    """
    if not os.path.exists(GENERATED_FILES_DIR):
        return  # nothing to clean up
    shutil.rmtree(GENERATED_FILES_DIR)

def split_image_grid(image, grid_cols, grid_rows):
    """Cut an image into a ``grid_rows`` x ``grid_cols`` grid of equal cells.

    Cells are returned row by row: left to right, then top to bottom. When the
    image size is not an exact multiple of the grid, the leftover pixels on the
    right and bottom edges are discarded.

    Args:
        image: A PIL image, or a numpy array accepted by ``Image.fromarray``.
        grid_cols: Number of columns; must be >= 1. Float values such as
            ``2.0`` (e.g. from a UI slider) are accepted and truncated to int.
        grid_rows: Number of rows; must be >= 1. Coerced like ``grid_cols``.

    Returns:
        list[np.ndarray]: One array per cell, in row-major order.

    Raises:
        ValueError: If ``image`` is None, a grid dimension is < 1, or the grid
            is finer than the image so that a cell would be empty.
    """
    if image is None:
        raise ValueError("No input image provided.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # range() and pixel coordinates need ints; slider values may be floats.
    cols, rows = int(grid_cols), int(grid_rows)
    if cols < 1 or rows < 1:
        raise ValueError(
            f"Grid must have at least 1 column and 1 row, got {cols}x{rows}."
        )

    cell_width = image.width // cols
    cell_height = image.height // rows
    if cell_width == 0 or cell_height == 0:
        raise ValueError(
            f"A {cols}x{rows} grid is too fine for a "
            f"{image.width}x{image.height} image."
        )

    # Crop box is (left, upper, right, lower).
    return [
        np.array(image.crop((
            col * cell_width,
            row * cell_height,
            (col + 1) * cell_width,
            (row + 1) * cell_height,
        )))
        for row in range(rows)
        for col in range(cols)
    ]

def interpolate_frames(frames, factor=2):
    """Insert linearly blended frames between each pair of consecutive frames.

    For every adjacent pair ``(a, b)`` the output contains ``a`` followed by
    ``factor - 1`` cross-fades ``(1 - t) * a + t * b`` with ``t = j / factor``
    for ``j = 1 .. factor - 1``. The last input frame is appended at the end.
    With ``factor <= 1`` no blended frames are added.

    Args:
        frames: Sequence of equally shaped image arrays accepted by
            ``cv2.addWeighted``.
        factor: Number of output steps per interval between input frames.

    Returns:
        list: The frames with blends inserted. For ``factor >= 1`` this has
        ``(len(frames) - 1) * factor + 1`` entries. An empty input yields an
        empty list.
    """
    # len() rather than truthiness so a stacked numpy array of frames also works.
    if len(frames) == 0:
        return []  # previously frames[-1] raised IndexError here

    result = []
    for frame1, frame2 in zip(frames, frames[1:]):
        result.append(frame1)
        for j in range(1, factor):
            t = j / factor
            result.append(cv2.addWeighted(frame1, 1 - t, frame2, t, 0))
    result.append(frames[-1])
    return result

def enhance_gif(images):
    """Apply automatic contrast stretching to each frame.

    Each frame is converted to RGB *before* ``ImageOps.autocontrast`` runs,
    because autocontrast only supports "L" and "RGB" images. Other modes, such
    as RGBA, would otherwise raise. For "L" and "RGB" input the result is the
    same as converting afterwards: autocontrast works per channel, and L->RGB
    only replicates the single channel.

    Args:
        images: Iterable of arrays accepted by ``Image.fromarray``.

    Returns:
        list[np.ndarray]: The contrast-enhanced frames as RGB arrays.
    """
    return [
        np.array(ImageOps.autocontrast(Image.fromarray(img).convert("RGB")))
        for img in images
    ]

def create_gif_imageio(images, fps=10, loop=0, filename="output_enhanced.gif"):
    """Write ``images`` as an animated GIF inside GENERATED_FILES_DIR.

    Args:
        images: Non-empty iterable of frame arrays (same shape).
        fps: Playback rate in frames per second; must be > 0.
        loop: GIF loop count; 0 means loop forever.
        filename: Output file name inside GENERATED_FILES_DIR. The default is
            a fixed name, so concurrent calls overwrite each other's output;
            pass a unique name to avoid that.

    Returns:
        str: Path of the written GIF.

    Raises:
        ValueError: If ``images`` is empty or ``fps`` is not positive.
    """
    frames = list(images)  # materialize so generators can be length-checked
    if not frames:
        raise ValueError("Cannot create a GIF from zero frames.")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}.")

    # Per-frame display time in milliseconds.
    # NOTE(review): imageio's unit for `duration` depends on the version.
    # Recent releases (Pillow-backed plugin) use milliseconds; older legacy
    # releases used seconds. Confirm against the installed imageio version.
    duration = 1000 / fps
    gif_path = os.path.join(GENERATED_FILES_DIR, filename)
    # imageio accepts numpy arrays directly; no PIL round-trip is needed.
    imageio.mimsave(gif_path, frames, duration=duration, loop=loop)
    return gif_path
    
def process_image(image, grid_cols_input, grid_rows_input):
    """Split ``image`` into grid cells and bundle them as PNGs in a ZIP.

    Returns:
        str: Path of the ZIP archive produced by ``zip_images``.
    """
    return zip_images(split_image_grid(image, grid_cols_input, grid_rows_input))

def process_image_to_gif(image, grid_cols_input, grid_rows_input, fps_input):
    """Turn the grid of frames in ``image`` into a contrast-enhanced GIF.

    Pipeline: split into grid cells -> insert one blended frame between each
    pair -> autocontrast each frame -> write an animated GIF.

    Returns:
        tuple[str, str]: ``(preview_path, gif_path)``. Both are the path of
        the generated GIF; the first feeds the preview image component and
        the second the download component.
    """
    frames = split_image_grid(image, grid_cols_input, grid_rows_input)
    frames = interpolate_frames(frames, factor=2)
    frames = enhance_gif(frames)
    gif_file = create_gif_imageio(frames, fps=fps_input, loop=0)
    # The preview shows the generated GIF itself. Reuse the returned path
    # rather than rebuilding it from a hard-coded file name.
    return gif_file, gif_file
    
def zip_images(images):
    """Encode each image array as PNG and pack the results into a ZIP archive.

    Entries are named ``image_0.png``, ``image_1.png``, ... in input order.
    The archive is written to a fixed path inside GENERATED_FILES_DIR, which
    overwrites any previous archive there, and that path is returned.
    """
    archive_path = os.path.join(GENERATED_FILES_DIR, "output.zip")
    with zipfile.ZipFile(archive_path, 'w') as archive:
        for index, frame in enumerate(images):
            png_bytes = BytesIO()
            Image.fromarray(frame).save(png_bytes, format='PNG')
            # getvalue() returns the whole buffer regardless of position.
            archive.writestr(f'image_{index}.png', png_bytes.getvalue())
    return archive_path

# Gradio UI: the input image plus grid dimensions and GIF frame rate.
#   "Create Zip File" -> process_image        -> zip_output
#   "Create GIF"      -> process_image_to_gif -> (preview_image, gif_output)
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        grid_cols_input = gr.Slider(1, 10, value=2, step=1, label="Grid Columns")
        grid_rows_input = gr.Slider(1, 10, value=2, step=1, label="Grid Rows")
        fps_input = gr.Slider(1, 30, value=10, step=1, label="FPS (Frames per Second)")
    with gr.Row():
        zip_button = gr.Button("Create Zip File")
        gif_button = gr.Button("Create GIF")
    with gr.Row():
        preview_image = gr.Image(label="Preview GIF Frame")
        zip_output = gr.File(label="Download Zip File")
        gif_output = gr.File(label="Download GIF")
    zip_button.click(process_image, inputs=[image_input, grid_cols_input, grid_rows_input], outputs=zip_output)
    gif_button.click(
        process_image_to_gif,
        inputs=[image_input, grid_cols_input, grid_rows_input, fps_input],
        outputs=[preview_image, gif_output]
    )

# Launch only when run as a script. Importing this module then builds `demo`
# without starting a blocking server as a side effect.
if __name__ == "__main__":
    demo.launch(show_error=True)