Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,62 +1,77 @@
|
|
1 |
-
import streamlit as st
|
2 |
import os
|
|
|
|
|
3 |
from PIL import Image
|
4 |
-
from
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")
|
9 |
|
10 |
-
#
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
#
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
with open(sketch_path, "wb") as f:
|
30 |
-
f.write(uploaded_file.getbuffer())
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
"dataset": "dresscode",
|
36 |
-
"dataset_path": dataset_path,
|
37 |
-
"output_dir": output_dir,
|
38 |
-
"guidance_scale": 7.5,
|
39 |
-
"guidance_scale_sketch": guidance_scale_sketch,
|
40 |
-
"mixed_precision": mixed_precision,
|
41 |
-
"batch_size": batch_size,
|
42 |
-
"seed": seed,
|
43 |
-
"save_name": "generated_image", # Output file name
|
44 |
-
}
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
output_path = main(args) # Call your backend main function
|
50 |
-
st.write("Image generation complete!")
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
else:
|
62 |
-
st.error("Please upload a sketch
|
|
|
|
|
1 |
import os
|
2 |
+
import torch
|
3 |
+
import streamlit as st
|
4 |
from PIL import Image
|
5 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
6 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
7 |
+
from src.mgd_pipelines.mgd_pipe import MGDPipe
|
8 |
+
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
|
|
|
9 |
|
10 |
+
# Function to load models
|
11 |
+
def load_models(pretrained_model_name_or_path, device):
    """Load the text/VAE/scheduler components and the MGD UNet.

    Args:
        pretrained_model_name_or_path: Checkpoint location handed to every
            ``from_pretrained`` call; each component is read from its own
            subfolder ("tokenizer", "text_encoder", "vae", "scheduler").
        device: Device passed to ``scheduler.set_timesteps``. The models
            themselves are NOT moved to this device here.

    Returns:
        The tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """

    def _from_subfolder(component_cls, subfolder):
        # Every component lives in a same-named subfolder of one checkpoint.
        return component_cls.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )

    tokenizer = _from_subfolder(CLIPTokenizer, "tokenizer")
    text_encoder = _from_subfolder(CLIPTextModel, "text_encoder")
    vae = _from_subfolder(AutoencoderKL, "vae")

    # The scheduler is configured for a fixed 50 denoising steps.
    scheduler = _from_subfolder(DDIMScheduler, "scheduler")
    scheduler.set_timesteps(50, device=device)

    # The UNet comes from the project's torch.hub entry point on GitHub
    # rather than from the checkpoint directory.
    hub_kwargs = {
        "repo_or_dir": "aimagelab/multimodal-garment-designer",
        "model": "mgd",
        "pretrained": True,
        "source": "github",
    }
    unet = torch.hub.load(**hub_kwargs)

    return tokenizer, text_encoder, vae, scheduler, unet
|
25 |
|
26 |
+
# Function to generate images
|
27 |
+
def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
    """Run the MGD pipeline on one sketch and one text prompt.

    Args:
        sketch: PIL image of the user's sketch (it must support ``resize``
            and ``convert``).
        prompt: Text description of the garment.
        tokenizer: Tokenizer used to encode ``prompt``.
        text_encoder: Text encoder module; moved to ``device``.
        vae: Autoencoder module; moved to ``device``.
        scheduler: Noise scheduler handed to the pipeline.
        unet: Denoising network; moved to ``device``.
        device: Device on which inputs and models are placed.

    Returns:
        The first element of the pipeline output (``outputs[0]``).
    """
    # Local import: the module header does not import numpy.
    import numpy as np

    # Preprocess inputs.
    # NOTE(review): PIL's resize takes (width, height), so this yields a
    # 512-wide, 384-tall image. Confirm this is the orientation the model
    # expects; it may want height 512 / width 384 instead.
    sketch = sketch.resize((512, 384)).convert("RGB")

    # Convert HWC uint8 pixels to a float CHW tensor in [0, 1] with a leading
    # batch axis. The previous code wrapped a tensor in a Python list and
    # passed it to torch.tensor, which cannot convert multi-element tensors.
    pixels = np.asarray(sketch, dtype=np.float32) / 255.0
    sketch_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    # Tokenize prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Initialize pipeline
    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)

    # Generate image; attention slicing is enabled before the forward pass.
    pipeline.enable_attention_slicing()
    with torch.inference_mode():
        # NOTE(review): the keyword names (images/text) and the indexable
        # return value are taken as-is from the original call; verify them
        # against MGDPipe.__call__.
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)

    return outputs[0]
|
|
|
|
|
50 |
|
51 |
+
# Streamlit UI: page header.
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs: a sketch image and a free-text garment description.
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

# Generation runs only on the rerun triggered by pressing the button.
if st.button("Generate"):
    if not (uploaded_file and text_prompt):
        # Both inputs are required; report the problem instead of generating.
        st.error("Please upload a sketch and enter a text description.")
    else:
        # Load models (prefer the GPU when one is available).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_model_path = "your-pretrained-model-path"  # Replace with actual model path
        tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)

        # Load sketch from the uploaded file object.
        sketch = Image.open(uploaded_file)

        # Generate image
        st.write("Generating the garment design...")
        output_image = generate_image(
            sketch,
            text_prompt,
            tokenizer,
            text_encoder,
            vae,
            scheduler,
            unet,
            device,
        )

        # Display output
        st.image(output_image, caption="Generated Garment Design", use_column_width=True)
|