Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,8 @@ from PIL import Image
|
|
| 14 |
|
| 15 |
# Setup and initialization code
|
| 16 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
| 17 |
-
|
|
|
|
| 18 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
| 19 |
os.environ["HF_HUB_CACHE"] = cache_path
|
| 20 |
os.environ["HF_HOME"] = cache_path
|
|
@@ -110,19 +111,6 @@ footer {display: none !important}
|
|
| 110 |
border-radius: 4px !important;
|
| 111 |
transition: transform 0.2s;
|
| 112 |
}
|
| 113 |
-
/* Force gallery items to maintain aspect ratio */
|
| 114 |
-
.gallery-item {
|
| 115 |
-
width: 100% !important;
|
| 116 |
-
aspect-ratio: 1 !important;
|
| 117 |
-
overflow: hidden !important;
|
| 118 |
-
}
|
| 119 |
-
.gallery-item img {
|
| 120 |
-
width: 100% !important;
|
| 121 |
-
height: 100% !important;
|
| 122 |
-
object-fit: cover !important;
|
| 123 |
-
border-radius: 4px;
|
| 124 |
-
transition: transform 0.2s;
|
| 125 |
-
}
|
| 126 |
.gallery-item img:hover {
|
| 127 |
transform: scale(1.05);
|
| 128 |
}
|
|
@@ -161,30 +149,60 @@ footer {display: none !important}
|
|
| 161 |
|
| 162 |
def save_image(image):
|
| 163 |
"""Save the generated image and return the path"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 165 |
-
|
|
|
|
| 166 |
filepath = os.path.join(gallery_path, filename)
|
| 167 |
|
| 168 |
-
|
| 169 |
-
image.
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
def load_gallery():
|
| 177 |
"""Load all images from the gallery directory"""
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# Create Gradio interface
|
| 183 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
| 184 |
gr.HTML('<div class="title">AI Image Generator</div>')
|
| 185 |
gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
|
| 186 |
|
| 187 |
-
|
| 188 |
with gr.Row():
|
| 189 |
with gr.Column(scale=3):
|
| 190 |
prompt = gr.Textbox(
|
|
@@ -269,12 +287,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
| 269 |
""")
|
| 270 |
|
| 271 |
with gr.Column(scale=4, elem_classes=["fixed-width"]):
|
| 272 |
-
|
| 273 |
output = gr.Image(
|
| 274 |
label="Generated Image",
|
| 275 |
elem_id="output-image",
|
| 276 |
elem_classes=["output-image", "fixed-width"]
|
| 277 |
)
|
|
|
|
|
|
|
| 278 |
gallery = gr.Gallery(
|
| 279 |
label="Generated Images Gallery",
|
| 280 |
show_label=True,
|
|
@@ -285,9 +305,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
| 285 |
object_fit="cover",
|
| 286 |
elem_classes=["gallery-container", "fixed-width"]
|
| 287 |
)
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
|
| 292 |
# Load existing gallery images on startup
|
| 293 |
gallery.value = load_gallery()
|
|
@@ -296,21 +313,27 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
| 296 |
def process_and_save_image(height, width, steps, scales, prompt, seed):
|
| 297 |
global pipe
|
| 298 |
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
# Connect the generation button to both the image output and gallery update
|
| 316 |
def update_seed():
|
|
|
|
| 14 |
|
| 15 |
# Setup and initialization code
|
| 16 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
| 17 |
+
# Change gallery path to user's home directory for persistence
|
| 18 |
+
gallery_path = path.join(os.path.expanduser("~"), "ai_generated_images")
|
| 19 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
| 20 |
os.environ["HF_HUB_CACHE"] = cache_path
|
| 21 |
os.environ["HF_HOME"] = cache_path
|
|
|
|
| 111 |
border-radius: 4px !important;
|
| 112 |
transition: transform 0.2s;
|
| 113 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
.gallery-item img:hover {
|
| 115 |
transform: scale(1.05);
|
| 116 |
}
|
|
|
|
| 149 |
|
| 150 |
def save_image(image):
|
| 151 |
"""Save the generated image and return the path"""
|
| 152 |
+
# Ensure gallery directory exists
|
| 153 |
+
os.makedirs(gallery_path, exist_ok=True)
|
| 154 |
+
|
| 155 |
+
# Generate unique filename with timestamp and random suffix
|
| 156 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 157 |
+
random_suffix = os.urandom(4).hex()
|
| 158 |
+
filename = f"generated_{timestamp}_{random_suffix}.png"
|
| 159 |
filepath = os.path.join(gallery_path, filename)
|
| 160 |
|
| 161 |
+
try:
|
| 162 |
+
if isinstance(image, Image.Image):
|
| 163 |
+
# Save with maximum quality
|
| 164 |
+
image.save(filepath, "PNG", quality=100)
|
| 165 |
+
else:
|
| 166 |
+
image = Image.fromarray(image)
|
| 167 |
+
image.save(filepath, "PNG", quality=100)
|
| 168 |
+
|
| 169 |
+
# Verify the file was saved correctly
|
| 170 |
+
if not os.path.exists(filepath):
|
| 171 |
+
print(f"Warning: Failed to verify saved image at {filepath}")
|
| 172 |
+
return None
|
| 173 |
+
|
| 174 |
+
return filepath
|
| 175 |
+
except Exception as e:
|
| 176 |
+
print(f"Error saving image: {str(e)}")
|
| 177 |
+
return None
|
| 178 |
|
| 179 |
def load_gallery():
|
| 180 |
"""Load all images from the gallery directory"""
|
| 181 |
+
try:
|
| 182 |
+
# Ensure gallery directory exists
|
| 183 |
+
os.makedirs(gallery_path, exist_ok=True)
|
| 184 |
+
|
| 185 |
+
# Get all image files and sort by modification time
|
| 186 |
+
image_files = []
|
| 187 |
+
for f in os.listdir(gallery_path):
|
| 188 |
+
if f.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 189 |
+
full_path = os.path.join(gallery_path, f)
|
| 190 |
+
image_files.append((full_path, os.path.getmtime(full_path)))
|
| 191 |
+
|
| 192 |
+
# Sort by modification time (newest first)
|
| 193 |
+
image_files.sort(key=lambda x: x[1], reverse=True)
|
| 194 |
+
|
| 195 |
+
# Return only the file paths
|
| 196 |
+
return [f[0] for f in image_files]
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"Error loading gallery: {str(e)}")
|
| 199 |
+
return []
|
| 200 |
|
| 201 |
# Create Gradio interface
|
| 202 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
| 203 |
gr.HTML('<div class="title">AI Image Generator</div>')
|
| 204 |
gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
|
| 205 |
|
|
|
|
| 206 |
with gr.Row():
|
| 207 |
with gr.Column(scale=3):
|
| 208 |
prompt = gr.Textbox(
|
|
|
|
| 287 |
""")
|
| 288 |
|
| 289 |
with gr.Column(scale=4, elem_classes=["fixed-width"]):
|
| 290 |
+
# Current generated image
|
| 291 |
output = gr.Image(
|
| 292 |
label="Generated Image",
|
| 293 |
elem_id="output-image",
|
| 294 |
elem_classes=["output-image", "fixed-width"]
|
| 295 |
)
|
| 296 |
+
|
| 297 |
+
# Gallery of generated images
|
| 298 |
gallery = gr.Gallery(
|
| 299 |
label="Generated Images Gallery",
|
| 300 |
show_label=True,
|
|
|
|
| 305 |
object_fit="cover",
|
| 306 |
elem_classes=["gallery-container", "fixed-width"]
|
| 307 |
)
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
# Load existing gallery images on startup
|
| 310 |
gallery.value = load_gallery()
|
|
|
|
| 313 |
def process_and_save_image(height, width, steps, scales, prompt, seed):
|
| 314 |
global pipe
|
| 315 |
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
|
| 316 |
+
try:
|
| 317 |
+
generated_image = pipe(
|
| 318 |
+
prompt=[prompt],
|
| 319 |
+
generator=torch.Generator().manual_seed(int(seed)),
|
| 320 |
+
num_inference_steps=int(steps),
|
| 321 |
+
guidance_scale=float(scales),
|
| 322 |
+
height=int(height),
|
| 323 |
+
width=int(width),
|
| 324 |
+
max_sequence_length=256
|
| 325 |
+
).images[0]
|
| 326 |
+
|
| 327 |
+
# Save the generated image
|
| 328 |
+
saved_path = save_image(generated_image)
|
| 329 |
+
if saved_path is None:
|
| 330 |
+
print("Warning: Failed to save generated image")
|
| 331 |
+
|
| 332 |
+
# Return both the generated image and updated gallery
|
| 333 |
+
return generated_image, load_gallery()
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print(f"Error in image generation: {str(e)}")
|
| 336 |
+
return None, load_gallery()
|
| 337 |
|
| 338 |
# Connect the generation button to both the image output and gallery update
|
| 339 |
def update_seed():
|