Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import boto3 | |
from botocore.config import Config | |
from dotenv import load_dotenv | |
import os | |
import shutil | |
from typing import List, Tuple, TYPE_CHECKING | |
import uuid | |
import argparse | |
import logging | |
from enum import Enum | |
import tempfile | |
from pathlib import Path | |
import requests | |
import banana_dev as banana | |
import streamlit as st | |
from PIL import Image | |
from streamlit_image_select import image_select | |
import smart_open | |
if TYPE_CHECKING: | |
from io import BytesIO | |
# Set up default logging handlers and let this module's logger emit
# INFO-level records (used to trace the zip/upload steps below).
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Looks for .env file in current directory to pull environment variables. Should
# not overwrite already set environment variables. Used for S3 credentials.
load_dotenv()

# Template for the S3 location of a zip archive; filled in with the run
# identifier and an image-type label via str.format.
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
# Endpoint queried with the user's API key to look up their account details.
_GRETEL_USERINFO_ENDPOINT = "https://api.gretel.cloud/users/me"
class UxState(str, Enum):
    """Screens of the app, stored in the session state under "ux_state".

    The ``str`` mixin makes every member compare equal to its plain string
    value.
    """

    # Authentication screen shown first.
    LOGIN_VIA_API_KEY = "login_via_api_key"

    # Up to three concept-upload screens.
    UPLOAD1 = "upload1"
    UPLOAD2 = "upload2"
    UPLOAD3 = "upload3"

    # Prompt entry, then training, then the completion screen.
    PROMPT = "prompt"
    TRAIN = "train"
    FINISHED = "finished"
# Command-line arguments to control some stuff for easier local testing.
# Eventually may want to move everything into functions and have an
# `if __name__ == "__main__"` setup instead of everything inline.
def setup_session_state():
    """Populate the Streamlit session state with any entries still missing.

    Entries that already exist are left untouched, so calling this on every
    script run keeps the values set during earlier runs of the same session.
    """
    state = st.session_state

    # The run identifier is handled on its own so a new uuid is only generated
    # when one is actually needed.
    if "key" not in state:
        state["key"] = uuid.uuid4().hex

    # Built fresh on every call, so the "concepts" list is never shared.
    initial_values = {
        "ux_state": UxState.LOGIN_VIA_API_KEY,
        "model_inputs": None,
        "concepts": [],
        "prompt_keywords": None,
        "prompt": None,
        "view": False,
        "user_email": None,
        "user_firstname": None,
        "user_verified": False,
    }
    for name, value in initial_values.items():
        if name not in state:
            state[name] = value
def bucket_parts(s3_path: str) -> Tuple[str, str]:
    """Break an S3 path into its bucket name and object key.

    Args:
        s3_path: path starting with "s3:", e.g. "s3://bucket/some/key"

    Returns:
        Tuple of (bucket, key). The key is an empty string when nothing
        follows the bucket name.
    """
    # "s3://bucket/some/key" -> ["s3:", "", "bucket", "some/key"]; limiting the
    # number of splits keeps the key in one piece.
    pieces = s3_path.split("/", 3)
    bucket = pieces[2]
    key = pieces[3] if len(pieces) > 3 else ""
    return bucket, key
def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
    """Create a presigned url for downloading the object at an S3 path.

    Whoever holds a presigned url can read the S3 object without having S3
    credentials of their own, until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url stays valid (this only limits the
            presigned url, not the lifetime of the underlying S3 object)

    Returns:
        The presigned url
    """
    bucket_name, object_key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    client = boto3.client("s3", config=client_config)
    return client.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket_name, "Key": object_key},
        ExpiresIn=expiration_seconds,
    )
def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
    """Create a presigned url for uploading an object to an S3 path.

    Whoever holds a presigned url can write to the S3 path without having S3
    credentials of their own, until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url stays valid (this only limits the
            presigned url, not the lifetime of the underlying S3 object)

    Returns:
        The presigned url
    """
    bucket_name, object_key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    client = boto3.client("s3", config=client_config)
    return client.generate_presigned_url(
        "put_object",
        Params={"Bucket": bucket_name, "Key": object_key},
        ExpiresIn=expiration_seconds,
    )
def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
    """Save images as a zip file on S3 for use in the backend.

    Blocks until the images are converted, added to a zip file, and the zip
    file is uploaded to S3. All intermediate files live in a temporary
    directory that is removed before returning.

    Args:
        identifier: unique identifier for the run, used in the S3 path
        uploaded_files: BytesIO or UploadedFile objects from the streamlit
            file uploader
        image_type: label distinguishing batches of images; it becomes part of
            the archive name. The upload screens in this file pass
            "concept_one", "concept_two" and "concept_three".

    Returns:
        S3 location of the zip file containing the png images.
    """
    with tempfile.TemporaryDirectory() as temp_dir_name:
        logger.info("Working from temp dir to zip and upload images: %s", temp_dir_name)
        temp_dir = Path(temp_dir_name)
        image_dir = temp_dir / identifier
        image_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Processing uploaded images")
        for num, uploaded_file in enumerate(uploaded_files):
            # Every upload is converted to RGB and stored as png, whatever
            # format it arrived in.
            file_ = Image.open(uploaded_file).convert("RGB")
            file_.save(image_dir / f"{num}_test.png")

        # The archive is written next to (not inside) the image directory so
        # it does not end up containing itself.
        local_zip_filestem = str(temp_dir / f"{identifier}_{image_type}_images")
        logger.info("Making zip archive")
        shutil.make_archive(local_zip_filestem, "zip", image_dir)
        local_zip_filename = f"{local_zip_filestem}.zip"

        logger.info("Uploading zip file to s3")
        # TODO: can we define expiration when making the s3 path?
        # Probably if we use the boto3 library instead of smart open
        s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
        with open(local_zip_filename, "rb") as fin, smart_open.open(s3_path, "wb") as fout:
            # Stream in chunks instead of loading the whole archive in memory.
            shutil.copyfileobj(fin, fout)

    logger.info("Completed upload to %s", s3_path)
    return s3_path
def train_model(model_inputs):
    """Submit the training request to the Banana backend.

    The credentials are taken from the ``BANANA_API_KEY`` and
    ``BANANA_MODEL_KEY`` environment variables when they are set, so they can
    be changed without editing the code. If a variable is missing, the value
    that used to be hard-coded here is used instead.

    Args:
        model_inputs: payload handed to ``banana.run``; its string form is
            also rendered on the page.
    """
    # NOTE(review): the fallback literals below are credentials committed to
    # source control. Rotate them, provide the new values through the
    # environment, and then remove these defaults.
    api_key = os.environ.get("BANANA_API_KEY", "03cdd72e-5c04-4207-bd6a-fd5712c1740e")
    model_key = os.environ.get("BANANA_MODEL_KEY", "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d")
    st.markdown(str(model_inputs))
    # The result is deliberately bound to a throwaway name and not used.
    _ = banana.run(api_key, model_key, model_inputs)
def switch_ux_state(new_state: UxState):
    """Make `new_state` the active screen and rerun the script right away.

    Args:
        new_state: the screen to show on the next run
    """
    session = st.session_state
    session["ux_state"] = new_state
    st.experimental_rerun()
def run_enter_api_key():
    """Show the API key form and verify the submitted key with Gretel.

    On submit, the key is sent in the ``authorization`` header to the user
    info endpoint. If the response is a 200 whose payload contains an e-mail
    address, the user's e-mail and first name are stored in the session
    state, the user is marked as verified, and the app moves on to the first
    upload screen. Otherwise an error is shown and the form stays in place.
    """
    form_placeholder = st.empty()
    with form_placeholder.form(key='user_auth_api_key'):
        api_key = st.text_input(label='Please enter your Gretel API Key', type='password')
        st.caption("Don't have a Gretel Cloud account yet? [Sign up](https://gretel.ai/signup) for free now!")
        submit_button = st.form_submit_button(label='Submit', type='primary')

    if not submit_button:
        return

    try:
        # Without a timeout an unresponsive server would hang this session
        # forever.
        r = requests.get(
            _GRETEL_USERINFO_ENDPOINT,
            headers={'authorization': api_key},
            timeout=10,
        )
    except requests.RequestException:
        logger.exception("Request to verify the API key failed")
        st.error('Could not reach Gretel Cloud to verify the API key, please try again')
        return

    if r.status_code != 200:
        st.error('API key could not be verified')
        return

    try:
        payload = r.json()
    except ValueError:
        # A 200 response whose body is not JSON cannot be trusted either.
        st.error('API key could not be verified')
        return

    # Walk the payload defensively: any level may be missing or null.
    data = payload.get('data') if isinstance(payload, dict) else None
    me = data.get('me') if isinstance(data, dict) else None
    email = me.get('email') if isinstance(me, dict) else None
    if email is None:
        st.error('No e-mail associated with this API key')
        return

    st.session_state["user_email"] = email
    st.session_state["user_firstname"] = me.get('firstname')
    st.session_state["user_verified"] = True
    switch_ux_state(UxState.UPLOAD1)
def run_upload_initial():
    """Show the upload form for the first concept and process a submission.

    On submit, the chosen images are zipped and uploaded to S3, and a concept
    entry (a presigned download url valid for one hour, the token and the
    class token) is appended to the session's "concepts" list. The app then
    moves to the second upload screen if the user ticked the checkbox, and to
    the prompt screen otherwise. Submitting without any files shows an error
    and uploads nothing.
    """
    identifier = st.session_state["key"]
    images = st.empty()
    with images.form("concept_one_form"):
        uploaded_files = st.file_uploader(
            "Choose first concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(
            """
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
"""
        )
        class_token = st.text_input("Token Class")
        st.caption(
            """
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
"""
        )
        concept = st.checkbox(
            'Would you like to fine-tune on a second concept?',
        )
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and registered as
    # a concept.
    if not uploaded_files:
        st.error("Please choose at least one image file before uploading.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_one")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')
    if concept:
        switch_ux_state(UxState.UPLOAD2)
    else:
        switch_ux_state(UxState.PROMPT)
def run_upload_secondary():
    """Show the upload form for the second concept and process a submission.

    On submit, the chosen images are zipped and uploaded to S3, and a concept
    entry (a presigned download url valid for one hour, the token and the
    class token) is appended to the session's "concepts" list. The app then
    moves to the third upload screen if the user ticked the checkbox, and to
    the prompt screen otherwise. Submitting without any files shows an error
    and uploads nothing.
    """
    identifier = st.session_state["key"]
    images = st.empty()
    with images.form("concept_two_form"):
        uploaded_files = st.file_uploader(
            "Choose second concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(
            """
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
"""
        )
        class_token = st.text_input("Token Class")
        st.caption(
            """
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
"""
        )
        next_concept = st.checkbox(
            'Would you like to fine-tune on a third concept?',
        )
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and registered as
    # a concept.
    if not uploaded_files:
        st.error("Please choose at least one image file before uploading.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_two")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')
    if next_concept:
        switch_ux_state(UxState.UPLOAD3)
    else:
        switch_ux_state(UxState.PROMPT)
def run_upload_third():
    """Show the upload form for the third concept and process a submission.

    On submit, the chosen images are zipped and uploaded to S3, and a concept
    entry (a presigned download url valid for one hour, the token and the
    class token) is appended to the session's "concepts" list. The app then
    moves to the prompt screen. Submitting without any files shows an error
    and uploads nothing.
    """
    identifier = st.session_state["key"]
    images = st.empty()
    with images.form("concept_three_form"):
        uploaded_files = st.file_uploader(
            "Choose third concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(
            """
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
"""
        )
        class_token = st.text_input("Token Class")
        st.caption(
            """
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
"""
        )
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and registered as
    # a concept.
    if not uploaded_files:
        st.error("Please choose at least one image file before uploading.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_three")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')
    switch_ux_state(UxState.PROMPT)
def run_prompts():
    """Show the prompt form and, on submit, move on to the training screen.

    The entered prompt and prompt keywords are stored in the session state
    under "prompt" and "prompt_keywords".
    """
    prompt_form = st.empty()
    with prompt_form.form("prompt_form"):
        full_prompt = st.text_input("Prompt")
        prompt_keywords = st.text_input("Prompt Keywords")
        submitted = st.form_submit_button("Submit")

    if submitted:
        st.session_state["prompt_keywords"] = prompt_keywords
        st.session_state["prompt"] = full_prompt
        # Go through switch_ux_state like every other screen does: it also
        # triggers the rerun, so the training screen shows up immediately
        # instead of only after the user's next interaction.
        switch_ux_state(UxState.TRAIN)
def run_train():
    """Assemble the model inputs, start the training run, then show the final screen.

    The inputs are kept in the session state under "model_inputs". They hold
    the uploaded concepts, the prompt data, the run identifier, the user's
    e-mail and two presigned urls for the generated-images archive.
    """
    state = st.session_state
    st.write("Congratulations, your model is training.")
    st.write(f"We'll send an email to {state['user_email']} when it's finished, usually about 20-30 minutes.")
    st.write("Closing this tab will not affect the ongoing image generation.")
    with st.spinner("Training in progress..."):
        state["model_inputs"] = {
            "concepts": state["concepts"],
            "num_images": 50,
            "prompt": state["prompt"],
            "prompt_keywords": state["prompt_keywords"],
        }
        model_inputs = state["model_inputs"]
        s3_output_path = _S3_PATH_OUTPUT.format(identifier=state["key"], image_type="generated")
        model_inputs["identifier"] = state["key"]
        model_inputs["email"] = state["user_email"]
        # The backend does not have s3 credentials, so it is handed presigned
        # urls for writing and reading the generated images.
        model_inputs["output_s3_url_get"] = generate_s3_get_url(
            s3_output_path,
            expiration_seconds=60 * 60 * 24,
        )
        model_inputs["output_s3_url_put"] = generate_s3_put_url(
            s3_output_path,
            expiration_seconds=3600,
        )
        train_model(state["model_inputs"])
    switch_ux_state(UxState.FINISHED)
def run_finished():
    """Show the completion message and where the results were e-mailed."""
    st.success('Image generation completed!')
    recipient = st.session_state['user_email']
    st.write(f"We've sent an email to {recipient} with a link to your generated images. Check it out!")
if __name__ == "__main__":
    setup_session_state()
    ux_state = st.session_state["ux_state"]
    # One render function per screen; exactly one of them runs per script run.
    runners = {
        UxState.LOGIN_VIA_API_KEY: run_enter_api_key,
        UxState.UPLOAD1: run_upload_initial,
        UxState.UPLOAD2: run_upload_secondary,
        UxState.UPLOAD3: run_upload_third,
        UxState.PROMPT: run_prompts,
        UxState.TRAIN: run_train,
        UxState.FINISHED: run_finished,
    }
    runner = runners.get(ux_state)
    if runner is None:
        raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
    runner()