Spaces:
Running
Running
import os | |
import base64 | |
import numpy as np | |
from PIL import Image | |
import io | |
import requests | |
import replicate | |
from flask import Flask, request | |
import gradio as gr | |
import openai | |
from openai import OpenAI | |
from dotenv import load_dotenv, find_dotenv | |
import json | |
# Find the nearest .env file (if any) and load its variables into os.environ,
# so the credentials below can be supplied either way.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

# API credentials read from the environment (None when unset).
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
REPLICATE_API_TOKEN = os.environ.get('REPLICATE_API_TOKEN')

# Shared OpenAI client used by call_openai(). Passing the key explicitly is
# equivalent to OpenAI(): with api_key=None the SDK falls back to the same
# OPENAI_API_KEY environment variable.
client = OpenAI(api_key=OPENAI_API_KEY)
def call_openai(pil_image):
    """Ask GPT-4o to summarise a moodboard's shared design language in one sentence.

    The image is JPEG-encoded and sent inline as a base64 data URL. The model
    is asked to complete the phrase "A render of an object which ...", so the
    returned text is meant to be appended directly after "which".

    Args:
        pil_image: the moodboard as a PIL image. Modes JPEG cannot store
            (e.g. RGBA, P) are converted to RGB before encoding.

    Returns:
        The model's completion text with surrounding whitespace stripped.

    Raises:
        gr.Error: if the API rejects the request, returns no text, or fails
            for any other reason; Gradio shows the message to the user.
    """
    # JPEG has no alpha channel or palette; without this conversion save()
    # raises OSError for RGBA/P/LA images.
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")

    # Encode the image as base64 JPEG so it can be embedded in a data URL.
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    image_data = base64.b64encode(buffered.getvalue()).decode('utf-8')

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "You are a product designer. I've attached a moodboard here. In one sentence, what do all of these elements have in common? Answer from a design language perspective, if you were telling another designer to create something similar, including any repeating colors and materials and shapes and textures. This is for a single product, so respond as though you're applying them to a single object. Reply with a completion to the following (don't include these words please, just the rest): [A render of an object which] [your response]. Do NOT include 'A render of an object which' in your response."},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "data:image/jpeg;base64," + image_data,
                            },
                        },
                    ],
                }
            ],
            max_tokens=300,
        )
        content = response.choices[0].message.content
    except openai.BadRequestError as e:
        # Typically an oversized or unsupported image; log details server-side.
        print(f"OpenAI rejected the moodboard ({type(e).__name__}): {e}")
        raise gr.Error("Please retry with a different moodboard file (below 20 MB in size and in one of the following formats: ['png', 'jpeg', 'gif', 'webp'])") from e
    except Exception as e:
        raise gr.Error(f"Unknown Error: {e}") from e

    # content is None when the model refuses or returns no text; the caller
    # concatenates this into a prompt, so fail with a readable message instead.
    if not content:
        raise gr.Error("The moodboard could not be described; please retry with a different image.")
    return content.strip()
# TODO: build a better prompt generator -- add another LLM layer that combines the user prompt and the moodboard description (for the jacket example, 'high quality render of yellow jacket, its fabric is a pattern of cosmic etc etc' worked well).
# This could even be run 4 separate times to get more diverse renders.
# Idea: add "simple" to the prompt before the product word.
def _first_output(output):
    """Return the first item of a Replicate result list, or the result itself if not a list."""
    # Indexing a single URL string would silently yield its first character.
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("model returned no outputs")
        return output[0]
    return output


def _download_image(url, timeout=60):
    """Fetch *url* over HTTP and return it fully decoded as a PIL image."""
    response = requests.get(str(url), timeout=timeout)
    # Never hand an HTTP error page to PIL.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    # Image.open is lazy; decode now so corrupt data fails inside the caller's try.
    img.load()
    return img


def image_classifier(moodboard, prompt):
    """Render four product mockups of *prompt* styled after *moodboard*.

    call_openai() turns the moodboard into a one-sentence style description,
    which is appended to the user's prompt. Replicate's stable-diffusion-3
    produces two renders (the second at 3:2 with cfg=6), and SDXL produces up
    to two more.

    Args:
        moodboard: image array from the Gradio "image" input, or None when
            nothing was uploaded (presumably HxWxC, uint8-compatible -- confirm).
        prompt: product text from the Gradio "text" input.

    Returns:
        A list of four PIL images in reverse generation order: any gray
        768x768 placeholders first, then the SDXL renders (last first), then
        the second and first SD3 renders.

    Raises:
        gr.Error: when no moodboard is given or any generation/download step
            fails, so Gradio shows a readable message.
    """
    if moodboard is None:
        raise gr.Error("Please upload a moodboard to control image generation style")
    pil_image = Image.fromarray(moodboard.astype('uint8'))
    openai_response = call_openai(pil_image)
    if not openai_response:
        raise gr.Error("The moodboard could not be described; please retry with a different image.")

    sd3_input = {
        "prompt": "high quality render of a " + prompt + " which " + openai_response + ", minimalist and simple mockup on a white background",
        "output_format": "jpg",
    }
    try:
        output = replicate.run("stability-ai/stable-diffusion-3", input=sd3_input)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    try:
        img1 = _download_image(_first_output(output))
    except Exception as e:
        raise gr.Error(f"Image download failed: {e}") from e

    # Second SD3 render: same prompt, wider 3:2 framing and a fixed CFG scale.
    sd3_input = {**sd3_input, "aspect_ratio": "3:2", "cfg": 6}
    try:
        output = replicate.run("stability-ai/stable-diffusion-3", input=sd3_input)
        img2 = _download_image(_first_output(output))
    except Exception as e:
        raise gr.Error(f"Second image generation failed: {e}") from e

    # SDXL renders from the same moodboard description.
    sdxl_input = {
        "width": 768,
        "height": 768,
        "prompt": "centered high quality render of a " + prompt + " which " + openai_response + ' centered on a plain white background',
        "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch, logo, buttons, markings, text, wires, complex, screws, nails, construction",
        "refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 25,
        "num_outputs": 2,
        "guidance_scale": 8.5,
    }
    try:
        output = replicate.run(
            "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
            input=sdxl_input,
        )
        sdxl_urls = output if isinstance(output, (list, tuple)) else [output]
        images = [img1, img2] + [_download_image(url) for url in sdxl_urls[:2]]
    except Exception as e:
        raise gr.Error(f"SDXL image generation failed: {e}") from e

    # Pad with gray placeholders so the interface always receives four images.
    images.extend(Image.new('RGB', (768, 768), 'gray') for _ in range(4 - len(images)))
    images.reverse()
    return images
# Gradio UI: a moodboard image plus a product text prompt in, four renders out.
demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", "text"],
    outputs=["image"] * 4,
)
demo.launch(share=True)