Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,9 @@ import numpy as np
|
|
3 |
import random
|
4 |
from diffusers import DiffusionPipeline
|
5 |
import torch
|
|
|
|
|
|
|
6 |
|
7 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
|
@@ -19,7 +22,7 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
19 |
MAX_IMAGE_SIZE = 1024
|
20 |
|
21 |
def infer(prompt_part1, color, dress_type, design, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
|
22 |
-
prompt = f"{prompt_part1} {color} colored {dress_type} with {design} design, {prompt_part5}
|
23 |
|
24 |
if randomize_seed:
|
25 |
seed = random.randint(0, MAX_SEED)
|
@@ -35,13 +38,18 @@ def infer(prompt_part1, color, dress_type, design, prompt_part5, negative_prompt
|
|
35 |
height=height,
|
36 |
generator=generator
|
37 |
).images[0]
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
return
|
40 |
|
41 |
examples = [
|
42 |
-
"red, t-shirt, yellow stripes",
|
43 |
-
"blue, hoodie, minimalist",
|
44 |
-
"red,
|
45 |
]
|
46 |
|
47 |
css = """
|
@@ -172,7 +180,7 @@ with gr.Blocks(css=css) as demo:
|
|
172 |
|
173 |
gr.Examples(
|
174 |
examples=examples,
|
175 |
-
inputs=[prompt_part2]
|
176 |
)
|
177 |
|
178 |
run_button.click(
|
|
|
3 |
import random
|
4 |
from diffusers import DiffusionPipeline
|
5 |
import torch
|
6 |
+
import base64
|
7 |
+
from io import BytesIO
|
8 |
+
from PIL import Image
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
|
|
|
22 |
MAX_IMAGE_SIZE = 1024
|
23 |
|
24 |
def infer(prompt_part1, color, dress_type, design, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
|
25 |
+
prompt = f"{prompt_part1} {color} colored {dress_type} with {design} design, {prompt_part5}"
|
26 |
|
27 |
if randomize_seed:
|
28 |
seed = random.randint(0, MAX_SEED)
|
|
|
38 |
height=height,
|
39 |
generator=generator
|
40 |
).images[0]
|
41 |
+
|
42 |
+
# Convert the PIL image to base64
|
43 |
+
buffered = BytesIO()
|
44 |
+
image.save(buffered, format="PNG")
|
45 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
46 |
|
47 |
+
return img_str
|
48 |
|
49 |
examples = [
|
50 |
+
["red", "t-shirt", "yellow stripes"],
|
51 |
+
["blue", "hoodie", "minimalist"],
|
52 |
+
["red", "sweatshirt", "geometric design"],
|
53 |
]
|
54 |
|
55 |
css = """
|
|
|
180 |
|
181 |
gr.Examples(
|
182 |
examples=examples,
|
183 |
+
inputs=[prompt_part2, prompt_part3, prompt_part4]
|
184 |
)
|
185 |
|
186 |
run_button.click(
|