File size: 5,571 Bytes
1cd7e2c
 
37e5f7e
1cd7e2c
61b9df0
1cd7e2c
 
 
0af4039
 
 
74a4765
 
 
 
 
1cd7e2c
 
 
dfa75b2
1cd7e2c
 
 
 
dfa75b2
 
 
 
 
 
1cd7e2c
 
 
0af4039
 
 
 
 
 
1cd7e2c
0af4039
 
 
 
 
 
dfa75b2
37e5f7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b1dbeb
dfa75b2
 
 
 
 
3b1dbeb
 
73cb8f9
3b1dbeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37e5f7e
3b1dbeb
1cd7e2c
 
 
5fa41d5
dfa75b2
 
 
5fa41d5
3b1dbeb
f7a1102
 
3b1dbeb
 
 
61b9df0
37e5f7e
 
 
 
 
 
 
 
61b9df0
1cd7e2c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import sys
import zipfile

import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

# image download and temporary-file handling
import requests
import tempfile
from io import BytesIO

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# os.getenv returns None when a variable is unset. Normalise that to "" so the
# emptiness checks below also catch variables that are missing altogether,
# not just ones that were explicitly set to an empty string.
openai_key = os.getenv("OPENAI_API_KEY") or ""
pw_key = os.getenv("PW") or ""

# Values still holding the template placeholder count as "not configured".
if openai_key == "<YOUR_OPENAI_API_KEY>":
    openai_key = ""

if pw_key == "<YOUR_PW>":
    pw_key = ""

# Refuse to start without both secrets.
if not pw_key:
    sys.exit("Please Provide A Password in the Environment Variables")

if not openai_key:
    sys.exit("Please Provide Your OpenAI API Key")

# Open the MongoDB client using the connection string from MONGO_URI
uri = os.getenv("MONGO_URI")
mongo_client = MongoClient(uri, server_api=ServerApi("1"))

# Work against the "images" collection of the "pdr" database
mongo_db = mongo_client["pdr"]
mongo_collection = mongo_db["images"]

# Ping the deployment once at start-up; a failure is printed but not fatal
try:
    mongo_client.admin.command("ping")
    print("Pinged your deployment. You successfully connected to MongoDB!")
except Exception as ping_error:
    print(ping_error)

# Paths of the most recently generated images, kept for the download handler
image_paths_global = []

def generate_images_wrapper(prompts, pw, model):
    """Run ``generate_images`` and remember its result for a later download.

    The returned list of image paths is stored in the module-level
    ``image_paths_global`` and also handed back to the caller. If generation
    raises, the stored value is left as it was.
    """
    global image_paths_global
    image_paths_global = generate_images(prompts, pw, model)
    return image_paths_global

def zip_images(image_paths):
    """Bundle the given image files into a new temporary zip archive.

    Each file is stored under its base name. Once the archive has been written
    successfully, the source files are deleted.

    Args:
        image_paths: Iterable of filesystem paths to add to the archive.

    Returns:
        The path of the created ``.zip`` file. The caller owns this file.

    Raises:
        OSError: If a source file cannot be read or the archive cannot be
            written. In that case the partial archive is removed and no
            source file is deleted.
    """
    # Reserve a unique file name and close the handle right away so no file
    # descriptor is leaked and the path can be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as placeholder:
        zip_file_path = placeholder.name

    image_paths = list(image_paths)
    try:
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for path in image_paths:
                zipf.write(path, arcname=os.path.basename(path))
    except Exception:
        # Do not leave a truncated archive behind.
        os.remove(zip_file_path)
        raise

    # Delete the sources only after the archive is complete, so a failure
    # part-way through never loses images that were already processed.
    for path in image_paths:
        os.remove(path)
    return zip_file_path

def download_all_images():
    """Zip the images remembered from the last generation and return the archive path.

    Raises ``gr.Error`` when there is nothing stored in ``image_paths_global``.
    After a successful zip the stored list is emptied.
    """
    global image_paths_global
    if image_paths_global:
        archive_path = zip_images(image_paths_global)
        # The zipped files are gone from disk, so forget their paths.
        image_paths_global = []
        return archive_path
    raise gr.Error("No images to download.")

def generate_images(prompts, pw, model):
    """Generate one image per prompt and return the downloaded file paths.

    ``prompts`` is a single string holding one or more prompts separated by
    semicolons. For every non-empty prompt an image is requested from the
    OpenAI images API, the prompt, model and image URL are inserted into
    ``mongo_collection``, and the image is downloaded into a temporary
    ``.png`` file.

    Args:
        prompts: Semicolon-separated prompt text.
        pw: Password entered by the user; must equal the ``PW`` environment
            variable.
        model: Image model name, ``"dall-e-2"`` or ``"dall-e-3"``.

    Returns:
        A list of temporary file paths, one per generated image, in prompt
        order.

    Raises:
        gr.Error: If the password is wrong, no prompt was supplied, or any
            step (generation, database insert, download) fails.
    """
    # Reject the request before doing any paid work if the password is wrong.
    if pw != os.getenv("PW"):
        raise gr.Error("Invalid password. Please try again.")

    # Prompts are separated by semicolons. Blank segments (empty input or a
    # trailing ";") are dropped instead of being sent to the API.
    prompts_list = [segment.strip() for segment in (prompts or "").split(';')]
    prompts_list = [segment for segment in prompts_list if segment]
    if not prompts_list:
        raise gr.Error("Please enter at least one prompt.")

    # One client serves every prompt in this request.
    try:
        client = OpenAI(api_key=openai_key)
    except Exception as error:
        print(str(error))
        raise gr.Error("Could not initialise the OpenAI client.") from error

    image_paths = []
    for text in prompts_list:
        try:
            response = client.images.generate(
                prompt=text,
                model=model,  # dall-e-2 or dall-e-3
                quality="standard",  # standard or hd
                # resolutions differ per model, see https://openai.com/pricing
                size="512x512" if model == "dall-e-2" else "1024x1024",
                n=1,  # number of images per prompt
            )
            image_url = response.data[0].url

            try:
                mongo_collection.insert_one({"text": text, "model": model, "image_url": image_url})
            except Exception as e:
                print(e)
                raise gr.Error("An error occurred while saving the prompt to the database.") from e

            # A timeout keeps a stalled download from blocking the request forever.
            image_response = requests.get(image_url, timeout=60)
            if image_response.status_code != 200:
                raise gr.Error("Failed to download the image.")

            # delete=False keeps the file on disk after it is closed so it can
            # be served; zip_images() removes it once it has been archived.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_file.write(image_response.content)
            print(temp_file.name)
            image_paths.append(temp_file.name)
        except gr.Error:
            # Already a user-facing message; do not mask it with the generic one.
            raise
        except Exception as error:
            print(str(error))
            raise gr.Error(f"An error occurred while generating the image for: {text}") from error

    return image_paths


# Build the Gradio UI: password and prompt inputs, a gallery for the generated
# images, and a separate button that bundles the last batch into a zip file.
with gr.Blocks() as demo:
    gr.Markdown("# <center> Prompt de Resistance Image Generator</center>")
    gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below, then click the download arrow from the top right of the image to save it.")
    pw = gr.Textbox(
        label="Password",
        type="password",
        placeholder="Enter the password to unlock the service",
    )
    # The placeholder names the button exactly as it is labelled below and
    # tells the user that ";" splits the input into several prompts.
    text = gr.Textbox(
        label="What do you want to create?",
        placeholder="Enter your text (separate multiple prompts with \";\") and then click on the \"Generate Images\" button",
    )

    model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
    btn = gr.Button("Generate Images")
    output_images = gr.Gallery(
        label="Image Outputs",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
        allow_preview=False,
    )

    # Pressing Enter in the prompt box and clicking the button run the same
    # handler. Only the submit event gets a named API endpoint; the click
    # event is kept out of the API with api_name=False.
    text.submit(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name="generate_image")
    btn.click(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name=False)

    # Zip whatever the last generation produced and offer it as a file download.
    download_all_btn = gr.Button("Download All")
    download_link = gr.File(label="Download Zip")
    download_all_btn.click(fn=download_all_images, inputs=[], outputs=download_link)


# share=True additionally requests a public share link for the app.
demo.launch(share=True)