File size: 6,368 Bytes
3178eaa
 
 
 
 
 
 
 
 
 
374f8ce
 
3178eaa
 
374f8ce
3178eaa
 
374f8ce
 
 
3178eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374f8ce
 
 
 
 
 
 
10a713a
374f8ce
 
 
 
 
3178eaa
374f8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
3178eaa
26c0e0c
 
 
 
0affb4b
3178eaa
374f8ce
 
 
 
10a713a
374f8ce
 
0affb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374f8ce
 
10a713a
374f8ce
 
 
 
0affb4b
 
 
 
3178eaa
374f8ce
 
 
 
 
 
 
 
 
 
 
10a713a
3178eaa
266a587
 
10a713a
 
266a587
 
 
10a713a
 
3178eaa
0affb4b
 
 
 
3178eaa
 
266a587
3178eaa
 
374f8ce
0affb4b
374f8ce
 
 
 
 
3178eaa
374f8ce
0affb4b
374f8ce
 
 
 
 
0affb4b
374f8ce
3178eaa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import base64
import numpy as np
from PIL import Image
import io
import requests

import replicate
from flask import Flask, request
import gradio as gr

import openai
from openai import OpenAI


from dotenv import load_dotenv, find_dotenv

import json


# Locate the .env file (searches upward from the current working directory).
dotenv_path = find_dotenv()

# Load environment variables from the .env file into os.environ.
load_dotenv(dotenv_path)

# NOTE(review): these two names are not referenced elsewhere in this file —
# the OpenAI and replicate clients read their API keys directly from the
# environment variables of the same names; confirm before removing.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
REPLICATE_API_TOKEN = os.getenv('REPLICATE_API_TOKEN')



# OpenAI client; picks up OPENAI_API_KEY from the environment loaded above.
client = OpenAI()


def call_openai(pil_image):
    """Ask GPT-4o to summarize the moodboard's shared design language.

    Parameters
    ----------
    pil_image : PIL.Image.Image
        The moodboard image to describe.

    Returns
    -------
    str
        A sentence completing "A render of an object which ...".

    Raises
    ------
    gr.Error
        If the OpenAI API rejects the image or any other failure occurs.
    """
    # JPEG cannot store an alpha channel / palette; normalize to RGB so
    # RGBA or palettized moodboards (e.g. PNGs) don't crash the save below.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Serialize the image to an in-memory JPEG and base64-encode it for the
    # data-URL payload expected by the vision API.
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    image_data = base64.b64encode(buffered.getvalue()).decode('utf-8')

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "You are a product designer. I've attached a moodboard here. In one sentence, what do all of these elements have in common? Answer from a design language perspective, if you were telling another designer to create something similar, including any repeating colors and materials and shapes and textures. This is for a single product, so respond as though you're applying them to a single object. Reply with a completion to the following (don't include these words please, just the rest): [A render of an object which] [your response]. Do NOT include 'A render of an object which' in your response."},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "data:image/jpeg;base64," + image_data,
                            },
                        },
                    ],
                }
            ],
            max_tokens=300,
        )
        return response.choices[0].message.content
    except openai.BadRequestError as e:
        # Keep a server-side trace of the rejected request for debugging.
        print(e)
        raise gr.Error("Please retry with a different moodboard file (below 20 MB in size and is of one the following formats: ['png', 'jpeg', 'gif', 'webp'])") from e
    except Exception as e:
        # Chain the cause so the real failure isn't silently discarded.
        raise gr.Error("Unknown Error") from e


# Todo -- better prompt generator, add another LLM layer combining the user prompt and moodboard description (in the case of the jacket, 'high quality render of yellow jacket, its fabric is a pattern of cosmic etc etc' worked well)
# Could even do this 4 different times to get more diversity of renders
# Add "simple" to prompt before word

def _resize_to_fit(pil_img, max_dim=768):
    """Downscale *pil_img* so neither side exceeds *max_dim*, keeping aspect ratio.

    Returns the image unchanged when it already fits.
    """
    width, height = pil_img.size
    if width <= max_dim and height <= max_dim:
        return pil_img
    if width >= height:
        new_width = max_dim
        new_height = int((max_dim / width) * height)
    else:
        new_height = max_dim
        new_width = int((max_dim / height) * width)
    return pil_img.resize((new_width, new_height), Image.LANCZOS)


def _b64_jpeg(pil_img):
    """Encode *pil_img* as a base64 JPEG string (no data-URL prefix)."""
    buffered = io.BytesIO()
    pil_img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _download_image(url):
    """Fetch *url* and return it as a PIL image."""
    response = requests.get(url)
    return Image.open(io.BytesIO(response.content))


def image_classifier(moodboard, starter_image, image_strength, prompt):
    """Generate three product renders styled after a moodboard.

    Parameters
    ----------
    moodboard : numpy.ndarray | None
        Moodboard image; its style is summarized by GPT-4o. Required.
    starter_image : numpy.ndarray | None
        Optional base image for img2img guidance.
    image_strength : float
        Slider value in [0, 1]; sent to replicate as
        prompt_strength = 1 - image_strength, so higher values preserve
        more of the starter image.
    prompt : str
        Short object description, e.g. "desk lamp".

    Returns
    -------
    list
        Exactly three PIL images (gray placeholders pad a short result).

    Raises
    ------
    gr.Error
        If no moodboard is supplied or a generation/download step fails.
    """
    if moodboard is None:
        raise gr.Error(f"Please upload a moodboard to control image generation style")

    pil_image = Image.fromarray(moodboard.astype('uint8'))
    openai_response = call_openai(pil_image)

    starter_image_base64 = None
    if starter_image is not None:
        starter_image_pil = Image.fromarray(starter_image.astype('uint8'))
        # BUG FIX: the original computed new dimensions for both orientations
        # but only called .resize() inside the "height larger" branch, so
        # landscape starter images were never actually downscaled.
        starter_image_pil = _resize_to_fit(starter_image_pil, 768)
        starter_image_base64 = _b64_jpeg(starter_image_pil)

    # First render: Stable Diffusion 3.
    sd3_input = {
        "prompt": "high quality render of a " + prompt + " which " + openai_response + ", minimalist and simple mockup on a white background",
        "output_format": "jpg"
    }
    if starter_image_base64 is not None:
        sd3_input["image"] = "data:image/jpeg;base64," + starter_image_base64
        sd3_input["prompt_strength"] = 1 - image_strength

    try:
        output = replicate.run(
            "stability-ai/stable-diffusion-3",
            input=sd3_input
        )
        img2 = _download_image(output[0])
    except Exception as e:
        raise gr.Error(f"Second image download failed: {e}")

    # Two more renders: SDXL, seeded with the same OpenAI style description.
    sdxl_input = {
        "width": 768,
        "height": 768,
        "prompt": "centered high quality render of a " + prompt + " which " + openai_response + ' centered on a plain white background',
        "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch, logo, buttons, markings, text, wires, complex, screws, nails, construction",
        "refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 25,
        "num_outputs": 2,
        "guidance_scale": 8.5
    }
    if starter_image_base64 is not None:
        sdxl_input["image"] = "data:image/jpeg;base64," + starter_image_base64
        sdxl_input["prompt_strength"] = 1 - image_strength

    output = replicate.run(
        "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
        input=sdxl_input
    )

    images = [img2]
    for image_url in output[:2]:
        images.append(_download_image(image_url))

    # Pad with gray placeholders so the UI always receives three images.
    while len(images) < 3:
        images.append(Image.new('RGB', (768, 768), 'gray'))

    images.reverse()
    return images

# Gradio UI wiring: moodboard image, optional starter image, an
# image-strength slider, and a text prompt in; three renders out.
strength_slider = gr.Slider(0, 1, step=0.05, value=0.2)
demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", "image", strength_slider, "text"],
    outputs=["image", "image", "image"],
)
demo.launch(share=True)