awacke1 commited on
Commit
27866da
1 Parent(s): feb60ec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import datetime
3
+ import gradio as gr
4
+ import numpy as np
5
+ import os
6
+ import pytz
7
+ import psutil
8
+ import re
9
+ import random
10
+ import torch
11
+ import time
12
+ import shutil
13
+ import zipfile
14
+ from PIL import Image
15
+ from io import BytesIO
16
+ from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
17
+
18
+ # ... [previous imports and setup code remains unchanged]
19
+
20
+ # New function to save prompt to history
21
+ def save_prompt_to_history(prompt):
22
+ with open("prompt_history.txt", "a") as f:
23
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
24
+ f.write(f"{timestamp}: {prompt}\n")
25
+
26
+ # Modified predict function
27
+ def predict(prompt, guidance, steps, seed=1231231):
28
+ generator = torch.manual_seed(seed)
29
+ last_time = time.time()
30
+ results = pipe(
31
+ prompt=prompt,
32
+ generator=generator,
33
+ num_inference_steps=steps,
34
+ guidance_scale=guidance,
35
+ width=512,
36
+ height=512,
37
+ output_type="pil",
38
+ )
39
+ print(f"Pipe took {time.time() - last_time} seconds")
40
+
41
+ # Save prompt to history
42
+ save_prompt_to_history(prompt)
43
+
44
+ # ... [rest of the function remains unchanged]
45
+
46
+ return results.images[0] if len(results.images) > 0 else None
47
+
48
+ # Modified save_all_images function
49
+ def save_all_images(images):
50
+ if len(images) == 0:
51
+ return None, None
52
+
53
+ timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
54
+ zip_filename = f"images_and_history_{timestamp}.zip"
55
+
56
+ with zipfile.ZipFile(zip_filename, 'w') as zipf:
57
+ # Add image files
58
+ for file in images:
59
+ zipf.write(file, os.path.basename(file))
60
+
61
+ # Add prompt history file
62
+ if os.path.exists("prompt_history.txt"):
63
+ zipf.write("prompt_history.txt")
64
+
65
+ # Generate download link
66
+ zip_base64 = encode_file_to_base64(zip_filename)
67
+ download_link = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_filename}">Download All (Images & History)</a>'
68
+
69
+ return zip_filename, download_link
70
+
71
+ # Function to read prompt history
72
+ def read_prompt_history():
73
+ if os.path.exists("prompt_history.txt"):
74
+ with open("prompt_history.txt", "r") as f:
75
+ return f.read()
76
+ return "No prompts yet."
77
+
78
+ # Modified Gradio interface
79
+ with gr.Blocks(css=css) as demo:
80
+ with gr.Column(elem_id="container"):
81
+ # ... [previous UI components remain unchanged]
82
+
83
+ # Add prompt history display
84
+ with gr.Accordion("Prompt History", open=False):
85
+ prompt_history = gr.Code(label="Prompt History", language="text", interactive=False)
86
+
87
+ # ... [rest of the UI components]
88
+
89
+ # Function to update prompt history display
90
+ def update_prompt_history():
91
+ return read_prompt_history()
92
+
93
+ # Connect components
94
+ generate_bt.click(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
95
+ prompt.submit(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
96
+
97
+ # Update prompt history when generating image or when accordion is opened
98
+ generate_bt.click(fn=update_prompt_history, outputs=prompt_history)
99
+ prompt.submit(fn=update_prompt_history, outputs=prompt_history)
100
+
101
+ # Modify save_all_button click event
102
+ save_all_button.click(
103
+ fn=lambda: save_all_images([f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]),
104
+ outputs=[gr.File(), gr.HTML()]
105
+ )
106
+
107
+ demo.queue()
108
+ demo.launch(allowed_paths=["/"])