santhosh97's picture
Update app.py
ada7876
raw
history blame
17.5 kB
from __future__ import annotations
import boto3
from botocore.config import Config
from dotenv import load_dotenv
import os
import shutil
from typing import List, Tuple, TYPE_CHECKING
import uuid
import argparse
import logging
from enum import Enum
import tempfile
from pathlib import Path
import requests
import banana_dev as banana
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
import smart_open
if TYPE_CHECKING:
from io import BytesIO
# Install the default root logging handler and let this module's logger emit
# INFO-level progress messages.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Looks for .env file in current directory to pull environment variables. Should
# not overwrite already set environment variables. Used for S3 credentials.
load_dotenv()
# Template for the S3 location of a zip archive. `identifier` is the per-session
# run id and `image_type` names the batch of images stored in the archive.
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
# Endpoint queried with the user's API key to look up their account details.
_GRETEL_USERINFO_ENDPOINT = "https://api.gretel.cloud/users/me"
class UxState(str, Enum):
    """Screens of the app, stored in the session under the "ux_state" key.

    Members are plain strings (the enum mixes in ``str``), so they compare
    equal to their literal values. They are listed in the order the app
    moves through them; the second and third upload screens are optional.
    """

    # Authentication screen shown first.
    LOGIN_VIA_API_KEY = "login_via_api_key"

    # Up to three concept-upload screens.
    UPLOAD1 = "upload1"
    UPLOAD2 = "upload2"
    UPLOAD3 = "upload3"

    # Prompt entry, then training, then the completion screen.
    PROMPT = "prompt"
    TRAIN = "train"
    FINISHED = "finished"
# Command-line arguments to control some behavior for easier local testing.
# Eventually we may want to move everything into functions and use an
# `if __name__ == "__main__":` entry point instead of running everything inline.
def setup_session_state():
    """Seed ``st.session_state`` with defaults for every key the app reads.

    Keys that are already present are left untouched, so calling this on
    every script run keeps earlier values intact.
    """
    state = st.session_state

    # The run identifier is handled separately so a new uuid is only
    # generated when one is actually needed.
    if "key" not in state:
        state["key"] = uuid.uuid4().hex

    # Built on each call so the mutable default (the list) is never shared.
    defaults = (
        ("ux_state", UxState.LOGIN_VIA_API_KEY),
        ("model_inputs", None),
        ("concepts", []),
        ("prompt_keywords", None),
        ("prompt", None),
        ("view", False),
        ("user_email", None),
        ("user_firstname", None),
        ("user_verified", False),
    )
    for name, initial in defaults:
        if name not in state:
            state[name] = initial
def bucket_parts(s3_path: str) -> Tuple[str, str]:
    """Break an S3 path into its bucket name and object key.

    Args:
        s3_path: path starting with "s3:", e.g. "s3://bucket/some/key"

    Returns:
        Tuple of (bucket, key). The key is an empty string when the path
        has nothing after the bucket.
    """
    # Pieces: "s3:", "", "<bucket>", and optionally everything after the
    # third slash kept together as one piece.
    pieces = s3_path.split("/", 3)
    bucket_name = pieces[2]
    object_key = pieces[3] if len(pieces) > 3 else ""
    return bucket_name, object_key
def generate_s3_get_url(s3_path: str, expiration_seconds: int) -> str:
    """Create a presigned url that grants read access to an S3 path.

    Anyone holding the url can read the object without S3 credentials until
    the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url stays valid. This only limits
            the url; it does not affect the lifetime of the S3 object.

    Returns:
        The presigned url
    """
    bucket_name, object_key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    client = boto3.client("s3", config=client_config)
    return client.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket_name, "Key": object_key},
        ExpiresIn=expiration_seconds,
    )
def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
    """Create a presigned url that grants write access to an S3 path.

    Anyone holding the url can write to the object without S3 credentials
    until the url expires.

    Args:
        s3_path: path starting with "s3:"
        expiration_seconds: how long the url stays valid. This only limits
            the url; it does not affect the lifetime of the S3 object.

    Returns:
        The presigned url
    """
    bucket_name, object_key = bucket_parts(s3_path)
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    client = boto3.client("s3", config=client_config)
    return client.generate_presigned_url(
        "put_object",
        Params={"Bucket": bucket_name, "Key": object_key},
        ExpiresIn=expiration_seconds,
    )
def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
    """Save images as zip file to s3 for use in backend.

    Blocks until images are processed, added to zip file, and uploaded to S3.
    Each image is converted to RGB and written as "<index>_test.png" before
    the directory is archived.

    Args:
        identifier: unique identifier for the run, used in s3 link
        uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
        image_type: string to identify different batches of images used in the
            backend model/training. It becomes part of the archive name.

    Returns:
        S3 location of zip file containing png images.
    """
    with tempfile.TemporaryDirectory() as temp_dir_name:
        logger.info("Working from temp dir to zip and upload images: %s", temp_dir_name)
        temp_dir = Path(temp_dir_name)
        image_dir = temp_dir / identifier
        image_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Processing uploaded images")
        for num, uploaded_file in enumerate(uploaded_files):
            # Close each decoded image as soon as its PNG copy is written.
            with Image.open(uploaded_file) as image:
                image.convert("RGB").save(image_dir / f"{num}_test.png")

        local_zip_filestem = str(temp_dir / f"{identifier}_{image_type}_images")
        logger.info("Making zip archive")
        # make_archive appends the ".zip" suffix and returns the final path.
        local_zip_filename = shutil.make_archive(local_zip_filestem, "zip", image_dir)

        logger.info("Uploading zip file to s3")
        # TODO: can we define expiration when making the s3 path?
        # Probably if we use the boto3 library instead of smart open
        s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
        with open(local_zip_filename, "rb") as fin, smart_open.open(s3_path, "wb") as fout:
            # Stream in chunks rather than loading the whole archive in memory.
            shutil.copyfileobj(fin, fout)

    logger.info("Completed upload to %s", s3_path)
    return s3_path
def train_model(model_inputs):
    """Submit a training/generation job to the Banana-hosted model.

    The credentials are read from the ``BANANA_API_KEY`` and
    ``BANANA_MODEL_KEY`` environment variables when set, and otherwise fall
    back to the values that were previously hard-coded here.

    Args:
        model_inputs: payload passed unchanged to ``banana.run``. It is also
            rendered on the page as text.
    """
    # NOTE(review): the fallback values below are secrets committed to source
    # control. They are kept only so existing deployments keep working; they
    # should be rotated and supplied through the environment instead.
    api_key = os.environ.get("BANANA_API_KEY", "03cdd72e-5c04-4207-bd6a-fd5712c1740e")
    model_key = os.environ.get("BANANA_MODEL_KEY", "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d")
    # NOTE(review): this shows the full payload (including presigned urls)
    # to the user; it looks like debugging output. Confirm it is intended.
    st.markdown(str(model_inputs))
    _ = banana.run(api_key, model_key, model_inputs)
def switch_ux_state(new_state: UxState):
    """Record the next screen and rerun the script so it is shown right away.

    Args:
        new_state: the screen to display on the next script run.
    """
    st.session_state["ux_state"] = new_state
    # `st.experimental_rerun` is the older, deprecated name for `st.rerun`.
    # Prefer the current name when this Streamlit version provides it and
    # fall back otherwise, so the app works on both.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
def run_enter_api_key():
    """Show the API-key form and verify the submitted key.

    On submit, the key is sent to the Gretel user-info endpoint. If the
    request succeeds and the response carries an e-mail address, the user's
    e-mail and first name are stored in the session, the user is marked as
    verified, and the app moves on to the first upload screen. Any failure
    is reported on the page and the user stays on this screen.
    """
    form_placeholder = st.empty()
    with form_placeholder.form(key='user_auth_api_key'):
        api_key = st.text_input(label='Please enter your Gretel API Key', type='password')
        st.caption("Don't have a Gretel Cloud account yet? [Sign up](https://gretel.ai/signup) for free now!")
        submit_button = st.form_submit_button(label='Submit', type='primary')

    if not submit_button:
        return

    try:
        # A timeout keeps the page from hanging if the endpoint is unreachable.
        r = requests.get(
            _GRETEL_USERINFO_ENDPOINT,
            headers={'authorization': api_key},
            timeout=10,
        )
    except requests.RequestException:
        logger.exception("Request to the user-info endpoint failed")
        st.error('API key could not be verified')
        return

    if r.status_code != 200:
        st.error('API key could not be verified')
        return

    try:
        payload = r.json()
    except ValueError:
        # The body was not valid JSON; treat it like a failed verification.
        st.error('API key could not be verified')
        return

    # Tolerate missing or null "data" / "me" entries in the response.
    me = (payload.get('data') or {}).get('me') or {}
    email = me.get('email')
    if email is None:
        st.error('No e-mail associated with this API key')
        return

    st.session_state["user_email"] = email
    st.session_state["user_firstname"] = me.get('firstname')
    st.session_state["user_verified"] = True
    switch_ux_state(UxState.UPLOAD1)
def run_upload_initial():
    """Show the upload form for the first concept and process its submission.

    On submit, the images are zipped and uploaded to S3, and a dict holding a
    presigned download url (valid for one hour), the token and the class
    token is appended to the session's "concepts" list. The app then moves to
    the second upload screen if the checkbox was ticked, otherwise to the
    prompt screen. Submitting without any files shows an error instead.
    """
    identifier = st.session_state["key"]
    token_help = "\nThe `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.\n"
    class_help = "\nThe `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.\n"

    images = st.empty()
    with images.form("concept_one_form"):
        uploaded_files = st.file_uploader(
            "Choose first concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(token_help)
        class_token = st.text_input("Token Class")
        st.caption(class_help)
        concept = st.checkbox(
            'Would you like to fine-tune on a second concept?',
        )
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and used.
    if not uploaded_files:
        st.error("Please upload at least one image before continuing.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_one")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')

    if concept:
        switch_ux_state(UxState.UPLOAD2)
    else:
        switch_ux_state(UxState.PROMPT)
def run_upload_secondary():
    """Show the upload form for the second concept and process its submission.

    On submit, the images are zipped and uploaded to S3, and a dict holding a
    presigned download url (valid for one hour), the token and the class
    token is appended to the session's "concepts" list. The app then moves to
    the third upload screen if the checkbox was ticked, otherwise to the
    prompt screen. Submitting without any files shows an error instead.
    """
    identifier = st.session_state["key"]
    token_help = "\nThe `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.\n"
    class_help = "\nThe `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.\n"

    images = st.empty()
    with images.form("concept_two_form"):
        uploaded_files = st.file_uploader(
            "Choose second concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(token_help)
        class_token = st.text_input("Token Class")
        st.caption(class_help)
        next_concept = st.checkbox(
            'Would you like to fine-tune on a third concept?',
        )
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and used.
    if not uploaded_files:
        st.error("Please upload at least one image before continuing.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_two")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')

    if next_concept:
        switch_ux_state(UxState.UPLOAD3)
    else:
        switch_ux_state(UxState.PROMPT)
def run_upload_third():
    """Show the upload form for the third concept and process its submission.

    On submit, the images are zipped and uploaded to S3, and a dict holding a
    presigned download url (valid for one hour), the token and the class
    token is appended to the session's "concepts" list. The app then moves to
    the prompt screen. Submitting without any files shows an error instead.
    """
    identifier = st.session_state["key"]
    token_help = "\nThe `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.\n"
    class_help = "\nThe `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.\n"

    images = st.empty()
    with images.form("concept_three_form"):
        uploaded_files = st.file_uploader(
            "Choose third concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
        )
        token = st.text_input("Token Name")
        st.caption(token_help)
        class_token = st.text_input("Token Class")
        st.caption(class_help)
        submitted = st.form_submit_button("Upload")

    if not submitted:
        return

    # Without this guard an empty archive would be uploaded and used.
    if not uploaded_files:
        st.error("Please upload at least one image before continuing.")
        return

    with st.spinner('Uploading...'):
        s3_path = zip_and_upload_images(identifier, uploaded_files, "concept_three")
        concept_information_dictionary = {
            "file_path": generate_s3_get_url(s3_path, expiration_seconds=3600),
            "token": token,
            "class_token": class_token,
        }
        st.session_state["concepts"].append(concept_information_dictionary)
    st.success(f'Uploading {len(uploaded_files)} files done!')

    switch_ux_state(UxState.PROMPT)
def run_prompts():
    """Show the prompt form and, on submit, store its values and start training.

    The prompt and the prompt keywords are saved in the session under
    "prompt" and "prompt_keywords", then the app switches to the training
    screen.
    """
    prompt_form = st.empty()
    with prompt_form.form("prompt_form"):
        full_prompt = st.text_input("Prompt")
        prompt_keywords = st.text_input("Prompt Keywords")
        submitted = st.form_submit_button("Submit")

    if submitted:
        st.session_state["prompt_keywords"] = prompt_keywords
        st.session_state["prompt"] = full_prompt
        # Go through switch_ux_state like the other screens do, so the script
        # reruns and the training screen appears right away. Only assigning
        # "ux_state" would leave this form on screen until the next
        # interaction.
        switch_ux_state(UxState.TRAIN)
def run_train():
    """Assemble the model inputs, submit the training job, then show the final screen.

    The inputs are stored in the session under "model_inputs" and combine the
    uploaded concepts, the prompt data, the run identifier, the user's e-mail
    and two presigned urls for the generated-images archive.
    """
    state = st.session_state
    st.write("Congratulations, your model is training.")
    st.write(f"We'll send an email to {state['user_email']} when it's finished, usually about 20-30 minutes.")
    st.write("Closing this tab will not affect the ongoing image generation.")

    with st.spinner("Training in progress..."):
        inputs = {
            "concepts": state["concepts"],
            "num_images": 50,
            "prompt": state["prompt"],
            "prompt_keywords": state["prompt_keywords"],
        }
        # Store the dict first; the remaining fields are added to it in place.
        state["model_inputs"] = inputs

        s3_output_path = _S3_PATH_OUTPUT.format(identifier=state["key"], image_type="generated")
        inputs["identifier"] = state["key"]
        inputs["email"] = state["user_email"]

        # The backend does not have s3 credentials, so generate
        # presigned urls for the backend to use to write and read
        # the generated images.
        one_hour = 3600
        one_day = 60 * 60 * 24
        inputs["output_s3_url_get"] = generate_s3_get_url(
            s3_output_path, expiration_seconds=one_day,
        )
        inputs["output_s3_url_put"] = generate_s3_put_url(
            s3_output_path, expiration_seconds=one_hour,
        )

        train_model(state["model_inputs"])

    switch_ux_state(UxState.FINISHED)
def run_finished():
    """Show the completion screen, naming the e-mail address that was notified."""
    recipient = st.session_state['user_email']
    st.success('Image generation completed!')
    st.write(f"We've sent an email to {recipient} with a link to your generated images. Check it out!")
if __name__ == "__main__":
    # Make sure every session key exists before any screen reads it.
    setup_session_state()
    ux_state = st.session_state["ux_state"]

    # Dispatch table: one render function per screen.
    runners = {
        UxState.LOGIN_VIA_API_KEY: run_enter_api_key,
        UxState.UPLOAD1: run_upload_initial,
        UxState.UPLOAD2: run_upload_secondary,
        UxState.UPLOAD3: run_upload_third,
        UxState.PROMPT: run_prompts,
        UxState.TRAIN: run_train,
        UxState.FINISHED: run_finished,
    }

    runner = runners.get(ux_state)
    if runner is None:
        raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
    runner()