Spaces:
Sleeping
Sleeping
joelorellana
committed on
Commit
·
1c681f7
1
Parent(s):
70e3e61
first commit for the project
Browse files- .gitignore +3 -0
- .streamlit/config.toml +6 -0
- app.py +83 -1
- dalle_generate_img.py +3 -4
- finetune_generate_img.py +3 -4
- gpt_vision_prompt.py +3 -0
- midjourney_generate_img.py +21 -2
- requirements.txt +124 -0
- stability_generate_img.py +11 -13
- test.py +86 -0
.gitignore
CHANGED
@@ -210,5 +210,8 @@ pyrightconfig.json
|
|
210 |
config.py
|
211 |
img/
|
212 |
output_img/
|
|
|
|
|
|
|
213 |
|
214 |
# End of https://www.toptal.com/developers/gitignore/api/python,macos
|
|
|
210 |
config.py
|
211 |
img/
|
212 |
output_img/
|
213 |
+
test.py
|
214 |
+
test2.py
|
215 |
+
|
216 |
|
217 |
# End of https://www.toptal.com/developers/gitignore/api/python,macos
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
primaryColor="#E694FF"
|
3 |
+
backgroundColor="#0E1117"
|
4 |
+
secondaryBackgroundColor="#31333F"
|
5 |
+
textColor="#FAFAFA"
|
6 |
+
font="sans serif"
|
app.py
CHANGED
@@ -1 +1,83 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from gpt_vision_prompt import generate_prompt_with_vision
|
3 |
+
import tempfile
|
4 |
+
from dalle_generate_img import generate_img_with_dalle
|
5 |
+
from stability_generate_img import generate_image_with_stability
|
6 |
+
from finetune_generate_img import generate_finetuned_img
|
7 |
+
from midjourney_generate_img import midjourney_generate_img
|
8 |
+
|
9 |
+
# Page configuration
|
10 |
+
st.set_page_config(layout="wide")
|
11 |
+
st.sidebar.title("API Keys")
|
12 |
+
st.markdown("<h1 style='text-align: center; color: grey;'>Image Generation App</h1>", unsafe_allow_html=True)
|
13 |
+
st.text("Prepared by [email protected] for fomo.ai")
|
14 |
+
|
15 |
+
# List of API key names
|
16 |
+
api_key_names = ["OPENAI_API_KEY", "MIDJOURNEY_GOAPI_TOKEN", "REPLICATE_API_TOKEN", "STABILITY_API_KEY"]
|
17 |
+
|
18 |
+
# Initialize session state if it does not exist
|
19 |
+
if 'api_keys' not in st.session_state:
|
20 |
+
st.session_state['api_keys'] = {key_name: "" for key_name in api_key_names}
|
21 |
+
if 'editable_prompt' not in st.session_state:
|
22 |
+
st.session_state['editable_prompt'] = ""
|
23 |
+
if 'upload_completed' not in st.session_state:
|
24 |
+
st.session_state['upload_completed'] = False
|
25 |
+
|
26 |
+
# Define a function to request and update API keys
|
27 |
+
def request_and_update_api_keys():
|
28 |
+
all_keys_entered = True
|
29 |
+
for key_name in api_key_names:
|
30 |
+
key_value = st.sidebar.text_input(f"Enter {key_name}:", value=st.session_state['api_keys'].get(key_name, ""), type="password", key=key_name)
|
31 |
+
st.session_state['api_keys'][key_name] = key_value
|
32 |
+
if not key_value:
|
33 |
+
all_keys_entered = False
|
34 |
+
return all_keys_entered
|
35 |
+
|
36 |
+
all_keys_entered = request_and_update_api_keys()
|
37 |
+
|
38 |
+
# Check if all API keys have been entered
|
39 |
+
if all_keys_entered:
|
40 |
+
# Section to upload the image
|
41 |
+
uploaded_file = st.file_uploader("Upload Image to analyze", type=['jpg', 'jpeg', 'png'], on_change=lambda: setattr(st.session_state, 'upload_completed', True))
|
42 |
+
if uploaded_file is not None:
|
43 |
+
st.session_state['upload_completed'] = True
|
44 |
+
left_co, cent_co, _ = st.columns([1, 2, 1])
|
45 |
+
with cent_co:
|
46 |
+
st.image(uploaded_file, caption="Uploaded Image")
|
47 |
+
_, right_co = st.columns([5, 1])
|
48 |
+
if right_co.button("Generate Prompt"):
|
49 |
+
with st.spinner("Generating Prompt..."):
|
50 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
|
51 |
+
temp_file.write(uploaded_file.getvalue())
|
52 |
+
temp_path = temp_file.name
|
53 |
+
api_key = st.session_state['api_keys']['OPENAI_API_KEY']
|
54 |
+
prompt = generate_prompt_with_vision(temp_path, api_key=api_key)
|
55 |
+
st.session_state['editable_prompt'] = prompt
|
56 |
+
st.session_state['upload_completed'] = False
|
57 |
+
|
58 |
+
if st.session_state['upload_completed']:
|
59 |
+
# Maintain the content of the editable prompt after generating images
|
60 |
+
editable_prompt = st.text_area("Edit the prompt as needed:", value=st.session_state['editable_prompt'], placeholder="Enter your prompt here...", height=150, key='editable_prompt', on_change=lambda: st.session_state.update(editable_prompt=editable_prompt))
|
61 |
+
if st.button("Generate New Image", key='generate_image_btn'):
|
62 |
+
col1, col2, col3, col4 = st.columns(4)
|
63 |
+
with col1:
|
64 |
+
with st.spinner("Generating DALL·E Image..."):
|
65 |
+
result_path_1 = generate_img_with_dalle(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['OPENAI_API_KEY'])
|
66 |
+
st.image(result_path_1, caption="DALL·E Image")
|
67 |
+
with col2:
|
68 |
+
with st.spinner("Generating Stable Diffusion Image..."):
|
69 |
+
result_path_2 = generate_image_with_stability(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['STABILITY_API_KEY'])
|
70 |
+
st.image(result_path_2, caption="Stable Diffusion Image")
|
71 |
+
with col3:
|
72 |
+
with st.spinner("Generating Finetuning Image..."):
|
73 |
+
result_path_3 = generate_finetuned_img(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['REPLICATE_API_TOKEN'])
|
74 |
+
st.image(result_path_3, caption="Finetuned SDXL Image")
|
75 |
+
with col4:
|
76 |
+
with st.spinner("Generating Midjourney Image..."):
|
77 |
+
result_path_4 = midjourney_generate_img(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['MIDJOURNEY_GOAPI_TOKEN'])
|
78 |
+
st.image(result_path_4, caption="Midjourney Image")
|
79 |
+
# Update the prompt in session state to keep the text
|
80 |
+
st.session_state['editable_prompt'] = editable_prompt
|
81 |
+
st.session_state['upload_completed'] = False # Disable the text area and button after generating images
|
82 |
+
else:
|
83 |
+
st.warning('Please enter all required API keys to proceed.', icon="⚠️")
|
dalle_generate_img.py
CHANGED
@@ -13,10 +13,9 @@ import requests
|
|
13 |
from config import OPENAI_API_KEY
|
14 |
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
def generate_img_with_dalle(prompt="", ):
|
19 |
"""Generate an image using the DALL-E API"""
|
|
|
20 |
# DALL-E model parameters
|
21 |
size = '1024x1024' # Choose between '1024x1024', '512x512', '256x256'
|
22 |
quality = 'hd' # Choose between 'standard', 'hd'
|
@@ -36,4 +35,4 @@ def generate_img_with_dalle(prompt="", ):
|
|
36 |
img = Image.open(io.BytesIO(response.content))
|
37 |
img.save('output_img/dalle_generated_img.png') # Save the image as a .png file
|
38 |
print('Image saved in output_img/dalle_generated_img.png')
|
39 |
-
return "
|
|
|
13 |
from config import OPENAI_API_KEY
|
14 |
|
15 |
|
16 |
+
def generate_img_with_dalle(prompt="", api_key=OPENAI_API_KEY):
|
|
|
|
|
17 |
"""Generate an image using the DALL-E API"""
|
18 |
+
client = OpenAI(api_key=api_key)
|
19 |
# DALL-E model parameters
|
20 |
size = '1024x1024' # Choose between '1024x1024', '512x512', '256x256'
|
21 |
quality = 'hd' # Choose between 'standard', 'hd'
|
|
|
35 |
img = Image.open(io.BytesIO(response.content))
|
36 |
img.save('output_img/dalle_generated_img.png') # Save the image as a .png file
|
37 |
print('Image saved in output_img/dalle_generated_img.png')
|
38 |
+
return "output_img/dalle_generated_img.png"
|
finetune_generate_img.py
CHANGED
@@ -10,11 +10,8 @@ import requests
|
|
10 |
from PIL import Image
|
11 |
from config import REPLICATE_API_TOKEN
|
12 |
|
13 |
-
# Set up environment variables for Replicate API
|
14 |
-
os.environ['REPLICATE_API_TOKEN'] = REPLICATE_API_TOKEN
|
15 |
|
16 |
-
|
17 |
-
def generate_finetuned_img(prompt):
|
18 |
"""
|
19 |
Generate a finetuned image based on the given prompt.
|
20 |
|
@@ -24,6 +21,8 @@ def generate_finetuned_img(prompt):
|
|
24 |
Returns:
|
25 |
str: The file path of the saved finetuned image.
|
26 |
"""
|
|
|
|
|
27 |
# Create finetuned image
|
28 |
print('Creating finetuned image...')
|
29 |
output = replicate.run(
|
|
|
10 |
from PIL import Image
|
11 |
from config import REPLICATE_API_TOKEN
|
12 |
|
|
|
|
|
13 |
|
14 |
+
def generate_finetuned_img(prompt, api_key=REPLICATE_API_TOKEN):
|
|
|
15 |
"""
|
16 |
Generate a finetuned image based on the given prompt.
|
17 |
|
|
|
21 |
Returns:
|
22 |
str: The file path of the saved finetuned image.
|
23 |
"""
|
24 |
+
# Set up environment variables for Replicate API
|
25 |
+
os.environ['REPLICATE_API_TOKEN'] = api_key
|
26 |
# Create finetuned image
|
27 |
print('Creating finetuned image...')
|
28 |
output = replicate.run(
|
gpt_vision_prompt.py
CHANGED
@@ -55,4 +55,7 @@ def generate_prompt_with_vision(image_path, prompt=PROMPT, api_key=OPENAI_API_KE
|
|
55 |
headers=headers,
|
56 |
json=payload,
|
57 |
timeout=30)
|
|
|
|
|
|
|
58 |
return response.json()['choices'][0]['message']['content']
|
|
|
55 |
headers=headers,
|
56 |
json=payload,
|
57 |
timeout=30)
|
58 |
+
print(response.status_code)
|
59 |
+
print(response.text)
|
60 |
+
print(response.json())
|
61 |
return response.json()['choices'][0]['message']['content']
|
midjourney_generate_img.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
""" Generate an image using the Midjourney API"""
|
2 |
import io
|
|
|
3 |
import requests
|
4 |
from PIL import Image
|
5 |
from progress_bar import print_progress_bar
|
@@ -12,13 +13,14 @@ headers = {
|
|
12 |
"X-API-KEY": GOAPIKEY
|
13 |
}
|
14 |
|
15 |
-
def midjourney_generate_img(prompt):
|
16 |
"""Generate an image using the Midjourney API
|
17 |
|
18 |
Keyword arguments:
|
19 |
prompt -- The prompt to generate the image from
|
20 |
Return: An image saved in a .png file
|
21 |
"""
|
|
|
22 |
img_generation_data = {
|
23 |
"prompt": prompt,
|
24 |
"aspect_ratio": "16:9",
|
@@ -66,4 +68,21 @@ def midjourney_generate_img(prompt):
|
|
66 |
img = Image.open(io.BytesIO(image_response.content))
|
67 |
img.save('output_img/midjourney_generated_img.png')
|
68 |
print("Image saved in output_img/midjourney_generated_img.png")
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
""" Generate an image using the Midjourney API"""
|
2 |
import io
|
3 |
+
import random
|
4 |
import requests
|
5 |
from PIL import Image
|
6 |
from progress_bar import print_progress_bar
|
|
|
13 |
"X-API-KEY": GOAPIKEY
|
14 |
}
|
15 |
|
16 |
+
def midjourney_generate_img(prompt, api_key=GOAPIKEY):
|
17 |
"""Generate an image using the Midjourney API
|
18 |
|
19 |
Keyword arguments:
|
20 |
prompt -- The prompt to generate the image from
|
21 |
Return: The path of the saved single-image .png file (output_img/midjourney_single_img.png)
|
22 |
"""
|
23 |
+
headers["X-API-KEY"] = api_key
|
24 |
img_generation_data = {
|
25 |
"prompt": prompt,
|
26 |
"aspect_ratio": "16:9",
|
|
|
68 |
img = Image.open(io.BytesIO(image_response.content))
|
69 |
img.save('output_img/midjourney_generated_img.png')
|
70 |
print("Image saved in output_img/midjourney_generated_img.png")
|
71 |
+
|
72 |
+
# Split the image into its four quadrants and keep only one of them
|
73 |
+
img_width, img_height = img.size
|
74 |
+
target_width = img_width // 2
|
75 |
+
target_height = img_height // 2
|
76 |
+
part = random.randint(1, 4) # select a random part
|
77 |
+
if part == 1:
|
78 |
+
img_cropped = img.crop((0, 0, target_width, target_height)) # Superior izquierda
|
79 |
+
elif part == 2:
|
80 |
+
img_cropped = img.crop((target_width, 0, img_width, target_height)) # Superior derecha
|
81 |
+
elif part == 3:
|
82 |
+
img_cropped = img.crop((0, target_height, target_width, img_height)) # Inferior izquierda
|
83 |
+
else:
|
84 |
+
img_cropped = img.crop((target_width, target_height, img_width, img_height)) # Inferior derecha
|
85 |
+
# save the selected img
|
86 |
+
img_cropped.save('output_img/midjourney_single_img.png')
|
87 |
+
print("Single image saved in output_img/midjourney_single_img.png")
|
88 |
+
return "output_img/midjourney_single_img.png"
|
requirements.txt
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
adbc_driver_manager==0.10.0
|
2 |
+
adbc_driver_postgresql==0.10.0
|
3 |
+
adbc_driver_sqlite==0.10.0
|
4 |
+
altair_saver==0.5.0
|
5 |
+
altair_viewer==0.4.0
|
6 |
+
anywidget==0.9.2
|
7 |
+
AppKit==0.2.8
|
8 |
+
atheris==2.3.0
|
9 |
+
beautifulsoup4==4.12.3
|
10 |
+
bokeh==3.2.1
|
11 |
+
boto3==1.34.54
|
12 |
+
botocore==1.29.76
|
13 |
+
brotlicffi==1.1.0.0
|
14 |
+
brotlipy==0.7.0
|
15 |
+
cached_property==1.5.2
|
16 |
+
chart_studio==1.1.0
|
17 |
+
ConfigParser==6.0.1
|
18 |
+
contextlib2==21.6.0
|
19 |
+
cryptography==41.0.3
|
20 |
+
ctypes_snappy==1.03
|
21 |
+
curio==1.6
|
22 |
+
cycler==0.11.0
|
23 |
+
Cython==3.0.8
|
24 |
+
cytoolz==0.12.0
|
25 |
+
defusedxml==0.7.1
|
26 |
+
diffusers==0.26.3
|
27 |
+
disco==1.40.4
|
28 |
+
dl==0.1.0
|
29 |
+
docutils==0.18.1
|
30 |
+
docutils==0.18.1
|
31 |
+
email_validator==2.1.1
|
32 |
+
eval_type_backport==0.1.3
|
33 |
+
exceptiongroup==1.2.0
|
34 |
+
fastparquet==2024.2.0
|
35 |
+
filelock==3.13.1
|
36 |
+
Foundation==0.1.0a0.dev1
|
37 |
+
fqdn==1.5.1
|
38 |
+
fsspec==2023.4.0
|
39 |
+
gitdb_speedups==0.1.0
|
40 |
+
gradio==4.19.2
|
41 |
+
grpc_reflection==1.0.0
|
42 |
+
h2==4.1.0
|
43 |
+
HTMLParser==0.0.2
|
44 |
+
hypothesis==6.98.15
|
45 |
+
ipython==8.12.3
|
46 |
+
ipywidgets==8.0.4
|
47 |
+
isoduration==20.11.0
|
48 |
+
jnius==1.1.0
|
49 |
+
JPype1==1.5.0
|
50 |
+
jsonpointer==2.1
|
51 |
+
keyframed==0.3.15
|
52 |
+
keyring==23.13.1
|
53 |
+
linkify_it_py==2.0.0
|
54 |
+
lxml==4.9.3
|
55 |
+
lz4==4.3.2
|
56 |
+
matplotlib==3.7.2
|
57 |
+
moto==5.0.2
|
58 |
+
mtrand==0.1
|
59 |
+
numarray==1.5.1
|
60 |
+
Numeric==24.2
|
61 |
+
numexpr==2.8.4
|
62 |
+
odfpy==1.4.1
|
63 |
+
olefile==0.47
|
64 |
+
openpyxl==3.0.10
|
65 |
+
outcome==1.3.0.post0
|
66 |
+
pickle5==0.0.12
|
67 |
+
pkgutil_resolve_name==1.3.10
|
68 |
+
plotly==5.19.0
|
69 |
+
psutil==5.9.0
|
70 |
+
pycares==4.4.0
|
71 |
+
pycurl==7.45.2
|
72 |
+
PyInstaller==6.4.0
|
73 |
+
pynvml==11.5.0
|
74 |
+
pyobjc_framework_Cocoa==9.0
|
75 |
+
pyodide==0.0.2
|
76 |
+
pyOpenSSL==23.2.0
|
77 |
+
pyOpenSSL==24.0.0
|
78 |
+
pyperf==2.6.2
|
79 |
+
PyQt4==4.11.4
|
80 |
+
PyQt5==5.15.10
|
81 |
+
PyQt5_sip==12.11.0
|
82 |
+
PyQt6==6.6.1
|
83 |
+
PySide6==6.6.2
|
84 |
+
pytest==7.4.0
|
85 |
+
python_calamine==0.2.0
|
86 |
+
pyxlsb==1.0.10
|
87 |
+
PyYAML==6.0
|
88 |
+
PyYAML==6.0.1
|
89 |
+
QtPy==2.2.0
|
90 |
+
railroad==0.5.0
|
91 |
+
redis==5.0.2
|
92 |
+
rfc3339_validator==0.1.4
|
93 |
+
rfc3986_validator==0.1.1
|
94 |
+
rfc3987==1.3.8
|
95 |
+
s3fs==2023.4.0
|
96 |
+
scikit_learn==1.3.0
|
97 |
+
scipy==1.12.0
|
98 |
+
sets==0.3.2
|
99 |
+
setuptools_scm==8.0.4
|
100 |
+
simplejson==3.19.2
|
101 |
+
slack_sdk==3.27.1
|
102 |
+
snowflake==0.6.0
|
103 |
+
socksio==1.0.0
|
104 |
+
Sphinx==5.0.2
|
105 |
+
SQLAlchemy==1.4.39
|
106 |
+
sympy==1.11.1
|
107 |
+
tables==3.8.0
|
108 |
+
testbench==0.1.2
|
109 |
+
threadpoolctl==3.3.0
|
110 |
+
torch==2.2.1
|
111 |
+
traitlets==5.14.1
|
112 |
+
transformers==4.32.1
|
113 |
+
trove_classifiers==2024.2.23
|
114 |
+
uri_template==1.3.0
|
115 |
+
urllib3_secure_extra==0.1.0
|
116 |
+
uvloop==0.19.0
|
117 |
+
vegafusion==1.6.5
|
118 |
+
watchdog==2.1.6
|
119 |
+
webcolors==1.13
|
120 |
+
xarray==2023.6.0
|
121 |
+
xlrd==2.0.1
|
122 |
+
xlsxwriter==3.2.0
|
123 |
+
xmlrpclib==1.0.1
|
124 |
+
zstandard==0.19.0
|
stability_generate_img.py
CHANGED
@@ -9,24 +9,13 @@ Return: An image saved in a .png file
|
|
9 |
import os
|
10 |
import io
|
11 |
import warnings
|
12 |
-
|
13 |
from stability_sdk import client
|
14 |
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
|
15 |
from PIL import Image
|
16 |
from config import STABILITY_API_KEY
|
17 |
|
18 |
-
# Set up environment variables for Stability API
|
19 |
-
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
|
20 |
-
os.environ['STABILITY_KEY'] = STABILITY_API_KEY
|
21 |
-
|
22 |
-
# Set up our connection to the Stability API.
|
23 |
-
stability_api = client.StabilityInference(
|
24 |
-
key=os.environ['STABILITY_KEY'],
|
25 |
-
verbose=True,
|
26 |
-
engine="stable-diffusion-xl-1024-v1-0",
|
27 |
-
)
|
28 |
|
29 |
-
def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024, samples=1):
|
30 |
"""
|
31 |
Generates an image based on the given prompt using Stability API.
|
32 |
|
@@ -39,6 +28,15 @@ def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, widt
|
|
39 |
:param samples: Number of images to generate.
|
40 |
:return: A PIL.Image object of the generated image.
|
41 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
print("Creating Stability Image...")
|
43 |
answers = stability_api.generate(
|
44 |
prompt=prompt,
|
@@ -63,6 +61,6 @@ def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, widt
|
|
63 |
img = Image.open(io.BytesIO(artifact.binary))
|
64 |
img.save("output_img/sd_generated_img.png")
|
65 |
print("Image saved in output_img/sd_generated_img.png")
|
66 |
-
return "
|
67 |
|
68 |
raise ValueError("No image was generated.")
|
|
|
9 |
import os
|
10 |
import io
|
11 |
import warnings
|
|
|
12 |
from stability_sdk import client
|
13 |
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
|
14 |
from PIL import Image
|
15 |
from config import STABILITY_API_KEY
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
def generate_image_with_stability(prompt, seed=42, steps=50, cfg_scale=7.0, width=1024, height=1024, samples=1, api_key=STABILITY_API_KEY):
|
19 |
"""
|
20 |
Generates an image based on the given prompt using Stability API.
|
21 |
|
|
|
28 |
:param samples: Number of images to generate.
|
29 |
:return: The file path of the saved image ("output_img/sd_generated_img.png").
|
30 |
"""
|
31 |
+
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
|
32 |
+
os.environ['STABILITY_KEY'] = api_key
|
33 |
+
# Set up our connection to the Stability API.
|
34 |
+
stability_api = client.StabilityInference(
|
35 |
+
key=os.environ['STABILITY_KEY'],
|
36 |
+
verbose=True,
|
37 |
+
engine="stable-diffusion-xl-1024-v1-0",
|
38 |
+
)
|
39 |
+
|
40 |
print("Creating Stability Image...")
|
41 |
answers = stability_api.generate(
|
42 |
prompt=prompt,
|
|
|
61 |
img = Image.open(io.BytesIO(artifact.binary))
|
62 |
img.save("output_img/sd_generated_img.png")
|
63 |
print("Image saved in output_img/sd_generated_img.png")
|
64 |
+
return "output_img/sd_generated_img.png"
|
65 |
|
66 |
raise ValueError("No image was generated.")
|
test.py
CHANGED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from gpt_vision_prompt import generate_prompt_with_vision
|
3 |
+
import tempfile
|
4 |
+
from dalle_generate_img import generate_img_with_dalle
|
5 |
+
from stability_generate_img import generate_image_with_stability
|
6 |
+
from finetune_generate_img import generate_finetuned_img
|
7 |
+
from midjourney_generate_img import midjourney_generate_img
|
8 |
+
|
9 |
+
|
10 |
+
# Page configuration
|
11 |
+
st.set_page_config(layout="wide")
|
12 |
+
st.sidebar.title("API Keys")
|
13 |
+
st.markdown("<h1 style='text-align: center; color: grey;'>Image Generation App</h1>", unsafe_allow_html=True)
|
14 |
+
st.text("Prepared by [email protected] for fomo.ai")
|
15 |
+
|
16 |
+
# List of API key names
|
17 |
+
api_key_names = ["OPENAI_API_KEY", "MIDJOURNEY_GOAPI_TOKEN", "REPLICATE_API_TOKEN", "STABILITY_API_KEY"]
|
18 |
+
|
19 |
+
# Initialize session state if it does not exist
|
20 |
+
if 'api_keys' not in st.session_state:
|
21 |
+
st.session_state['api_keys'] = {key_name: "" for key_name in api_key_names}
|
22 |
+
if 'editable_prompt' not in st.session_state:
|
23 |
+
st.session_state['editable_prompt'] = ""
|
24 |
+
|
25 |
+
# Define a function to request and update API keys
|
26 |
+
def request_and_update_api_keys():
|
27 |
+
all_keys_entered = True
|
28 |
+
for key_name in api_key_names:
|
29 |
+
key_value = st.sidebar.text_input(f"Enter {key_name}:", value=st.session_state['api_keys'].get(key_name, ""), type="password", key=key_name)
|
30 |
+
st.session_state['api_keys'][key_name] = key_value
|
31 |
+
if not key_value:
|
32 |
+
all_keys_entered = False
|
33 |
+
return all_keys_entered
|
34 |
+
|
35 |
+
all_keys_entered = request_and_update_api_keys()
|
36 |
+
|
37 |
+
# Check if all API keys have been entered
|
38 |
+
if all_keys_entered:
|
39 |
+
# Section to upload the image
|
40 |
+
uploaded_file = st.file_uploader("Upload Image to analyze", type=['jpg', 'jpeg', 'png'])
|
41 |
+
if uploaded_file is not None:
|
42 |
+
left_co, cent_co,last_co = st.columns(3)
|
43 |
+
with cent_co:
|
44 |
+
st.image(uploaded_file, caption="Uploaded Image")
|
45 |
+
# Button to generate the prompt, only if an image has been uploaded
|
46 |
+
if st.button("Generate Prompt"):
|
47 |
+
with st.spinner("Generating Prompt..."):
|
48 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
|
49 |
+
temp_file.write(uploaded_file.getvalue())
|
50 |
+
temp_path = temp_file.name
|
51 |
+
api_key = st.session_state['api_keys']['OPENAI_API_KEY']
|
52 |
+
prompt = generate_prompt_with_vision(temp_path, api_key=api_key)
|
53 |
+
st.success("Done!")
|
54 |
+
st.session_state['editable_prompt'] = prompt # Actualizar el prompt en el estado de la sesión
|
55 |
+
|
56 |
+
editable_prompt = st.text_area("Edit the prompt as needed:", placeholder="Enter your prompt here...", height=150, key='editable_prompt', label_visibility='hidden')
|
57 |
+
|
58 |
+
col1, col2, col3, col4 = st.columns(4)
|
59 |
+
|
60 |
+
if st.session_state['editable_prompt'] and st.button("Generate New Image"):
|
61 |
+
with col1:
|
62 |
+
with st.spinner("Generating DALL·E Image..."):
|
63 |
+
result_path_1 = generate_img_with_dalle(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['OPENAI_API_KEY'])
|
64 |
+
st.success("Generated DALL·E Image!")
|
65 |
+
st.image(result_path_1, caption="DALL·E Image")
|
66 |
+
|
67 |
+
with col2:
|
68 |
+
with st.spinner("Generating Stable Diffusion Image..."):
|
69 |
+
result_path_2 = generate_image_with_stability(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['STABILITY_API_KEY'])
|
70 |
+
st.success("Generated Stable Diffusion Image!")
|
71 |
+
st.image(result_path_2, caption="Stable Diffusion Image")
|
72 |
+
|
73 |
+
with col3:
|
74 |
+
with st.spinner("Generating Finetuning Image..."):
|
75 |
+
result_path_3 = generate_finetuned_img(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['REPLICATE_API_TOKEN'])
|
76 |
+
st.success("Generated Image using a finetuned model!")
|
77 |
+
st.image(result_path_3, caption="Finetuned SDXL Image")
|
78 |
+
|
79 |
+
with col4:
|
80 |
+
with st.spinner("Generating Midjourney Image..."):
|
81 |
+
result_path_4 = midjourney_generate_img(st.session_state['editable_prompt'], api_key=st.session_state['api_keys']['MIDJOURNEY_GOAPI_TOKEN'])
|
82 |
+
st.success("Generated Midjourney Image!")
|
83 |
+
st.image(result_path_4, caption="Midjourney Image")
|
84 |
+
|
85 |
+
else:
|
86 |
+
st.warning('Please enter all required API keys to proceed.', icon="⚠️")
|