Spaces:
Running
Running
import uuid | |
import comet_ml | |
import gradio as gr | |
import pandas as pd | |
from PIL import Image | |
from transformers import CLIPModel, CLIPProcessor | |
# Identifier of the CLIP checkpoint to load; the same value is recorded in the
# "clip_model" column of every prediction table logged by predict().
CLIP_MODEL_PATH = "openai/clip-vit-base-patch32"
# The model and its processor are loaded once, at import time, and reused by
# every call to predict().
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_PATH)
# Markdown rendered as the first component of the interface.
# NOTE(review): the sentence ending in "add in the" looks truncated, and the
# "π" after "here" may be a mis-encoded emoji — confirm the intended wording.
DESCRIPTION = """Glad to see you here π.
You can use this Space to log predictions to [Comet](https://www.comet.ml/site) from Spaces that use Text to Image Diffusion Models.
Keep track of all your prompts and generated images so that you remember the good ones!
Set your Comet credentials in the Comet Settings tab and create an Experiment for logging data.
If you want to continue logging to the same Experiment over multiple sessions, add in the
Then use the path to a Space to generate from in the Diffusion Model tab
"""
def start_experiment(
    comet_api_key,
    comet_workspace,
    comet_project_name,
    comet_experiment_name,
    experiment,
):
    """Create a Comet Experiment, or attach to an existing one, for logging.

    Args:
        comet_api_key: Comet API key. When empty, nothing is created and a
            sign-up hint is returned instead.
        comet_workspace: Workspace forwarded to Comet.
        comet_project_name: Project name forwarded to Comet.
        comet_experiment_name: Name of an existing Experiment to keep logging
            to. When empty, a new Experiment is created.
        experiment: Previous experiment state. It is accepted because the UI
            wires the state component in as an input; its value is not read.

    Returns:
        A ``(state, message)`` tuple. ``state`` is a dict with the keys
        ``api_key``, ``workspace``, ``project_name`` and ``experiment``, or
        ``None`` when no Experiment could be started. ``message`` is always a
        string suitable for display in the status box.
    """
    if not comet_api_key:
        return (
            None,
            """
Please add your API key in order to log your predictions to a Comet Experiment.
If you don't have a Comet account yet, you can sign up using the link below:
https://www.comet.ml/signup
""",
        )

    try:
        if comet_experiment_name:
            api_experiment = get_experiment(
                {
                    "api_key": comet_api_key,
                    "workspace": comet_workspace,
                    "project_name": comet_project_name,
                    "experiment": comet_experiment_name,
                }
            )
            # get_experiment() signals failure with None; report that
            # explicitly instead of tripping over `None.name` below.
            if api_experiment is None:
                return (
                    None,
                    f"Could not find Experiment {comet_experiment_name!r}. "
                    "Please check the API key, workspace, project and "
                    "experiment names.",
                )
        else:
            api_experiment = comet_ml.APIExperiment(
                api_key=comet_api_key,
                workspace=comet_workspace,
                project_name=comet_project_name,
            )

        experiment = {
            "api_key": comet_api_key,
            "workspace": comet_workspace,
            "project_name": comet_project_name,
            "experiment": api_experiment.name,
        }
        return experiment, f"Started {api_experiment.name}. Happy logging!π"
    except Exception as e:
        # UI boundary: surface the failure as text in the status box rather
        # than handing the output component an exception object.
        return None, str(e)
def get_experiment(experiment_state):
    """Look up the Comet Experiment described by ``experiment_state``.

    Args:
        experiment_state: Dict with the keys ``api_key``, ``workspace``,
            ``project_name`` and ``experiment``, or ``None`` when no
            Experiment has been started yet.

    Returns:
        The result of ``comet_ml.API(...).get_experiment(...)``, or ``None``
        when there is no state or the lookup raises.
    """
    # No state means no Experiment was started; say so directly instead of
    # relying on an AttributeError being swallowed below.
    if experiment_state is None:
        return None

    try:
        api_key = experiment_state.get("api_key")
        workspace = experiment_state.get("workspace")
        project = experiment_state.get("project_name")
        experiment_name = experiment_state.get("experiment")
        return comet_ml.API(api_key=api_key).get_experiment(
            workspace=workspace, project_name=project, experiment=experiment_name
        )
    except Exception:
        # Deliberately best-effort: callers treat None as "no Experiment".
        return None
def get_experiment_status(experiment_state):
    """Report which Experiment, if any, the current state resolves to.

    Args:
        experiment_state: State passed to ``get_experiment`` for the lookup.

    Returns:
        ``(experiment_state, message)``: the state unchanged, plus a status
        line naming the Experiment or saying that none was found.
    """
    active = get_experiment(experiment_state)
    if active is None:
        return experiment_state, "No Experiments found"

    return experiment_state, f"Currently logging to: {active.name}"
def predict(
    model,
    prompt,
    experiment_state,
):
    """Generate an image for ``prompt``, score it with CLIP and log it.

    Args:
        model: Path handed to ``gr.Interface.load`` to obtain the generator.
        prompt: Text sent to the loaded interface and scored against the
            resulting image.
        experiment_state: State passed to ``get_experiment``; when no
            Experiment is resolved, nothing is logged.

    Returns:
        ``(image, experiment_state)``: the loaded interface's output and the
        state passed through unchanged.
    """
    io = gr.Interface.load(model)
    image = io(prompt)

    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases it once the processor has consumed the pixels.
    with Image.open(image) as pil_image:
        inputs = clip_processor(
            text=[prompt],
            images=pil_image,
            return_tensors="pt",
            padding=True,
        )
    outputs = clip_model(**inputs)
    # NOTE(review): the division by 100 presumably undoes CLIP's logit
    # scale — confirm against the checkpoint in use.
    clip_score = outputs.logits_per_image.item() / 100.0

    experiment = get_experiment(experiment_state)
    if experiment is not None:
        # One random id ties the logged image to its metadata table.
        image_id = uuid.uuid4().hex
        experiment.log_image(image, image_id)

        asset = pd.DataFrame.from_records(
            [
                {
                    "prompt": prompt,
                    "model": model,
                    "clip_model": CLIP_MODEL_PATH,
                    "clip_score": round(clip_score, 3),
                }
            ]
        )
        experiment.log_table(f"{image_id}.json", asset, orient="records")

    return image, experiment_state
def start_interface():
    """Build the two-tab Gradio UI and launch it.

    The "Comet Settings" tab collects credentials and starts or inspects an
    Experiment; the "Diffusion Model" tab sends a prompt to a Space and logs
    the result. Both tabs share one experiment-state component.
    """
    demo = gr.Blocks()

    with demo:
        gr.Markdown(DESCRIPTION)
        with gr.Tabs():
            with gr.TabItem(label="Comet Settings"):
                # credentials
                comet_api_key = gr.Textbox(
                    label="Comet API Key",
                    placeholder="This is required if you'd like to create an Experiment",
                )
                comet_workspace = gr.Textbox(label="Comet Workspace")
                comet_project_name = gr.Textbox(label="Comet Project Name")
                comet_experiment_name = gr.Textbox(
                    label="Comet Experiment Name",
                    # A single string: the original had a trailing comma that
                    # made this a 1-tuple, and no space between the fragments.
                    placeholder=(
                        "Set this if you'd like "
                        "to continue logging to an existing Experiment"
                    ),
                )

                with gr.Row():
                    start = gr.Button("Start Experiment", variant="primary")
                    status = gr.Button("Experiment Status")

                status_output = gr.Textbox(label="Status")
                # Holds the experiment dict between callbacks; also used by
                # the Diffusion Model tab below.
                experiment_state = gr.Variable(label="Experiment State")

                start.click(
                    start_experiment,
                    inputs=[
                        comet_api_key,
                        comet_workspace,
                        comet_project_name,
                        comet_experiment_name,
                        experiment_state,
                    ],
                    outputs=[experiment_state, status_output],
                )
                status.click(
                    get_experiment_status,
                    inputs=[experiment_state],
                    outputs=[experiment_state, status_output],
                )

            with gr.TabItem(label="Diffusion Model"):
                gr.Markdown(
                    """The Model must be a path to any Space that accepts
only text as input and produces an image as an output
"""
                )
                model = gr.Textbox(label="Model", value="spaces/valhalla/glide-text2im")
                prompt = gr.Textbox(label="Prompt")
                outputs = gr.Image(label="Image")

                submit = gr.Button("Submit", variant="primary")
                submit.click(
                    predict,
                    inputs=[model, prompt, experiment_state],
                    outputs=[outputs, experiment_state],
                )

    demo.launch()
# Build and launch the interface as soon as this module is executed.
start_interface()