Update app.py

app.py CHANGED
@@ -1,77 +1,107 @@
-import os
-import torch
 import streamlit as st
+import torch
+import numpy as np
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDIMScheduler
-from src.mgd_pipelines.mgd_pipe import MGDPipe
-def load_models(
+from src.mgd_pipelines.mgd_pipe import MGDPipe  # Use your implementation of MGDPipe
+
+
+# Load models and pipeline
+def load_models(pretrained_model_path, device):
+    """
+    Load the models required for the MGDPipe.
+    Args:
+        pretrained_model_path (str): Path or Hugging Face identifier for the model.
+        device (torch.device): Device to load the models on.
+
+    Returns:
+        MGDPipe: Initialized MGDPipe object.
+    """
+    # Load components of Stable Diffusion
+    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+    scheduler.set_timesteps(50)
+
+    # Load the UNet model
     unet = torch.hub.load(
         repo_or_dir="aimagelab/multimodal-garment-designer",
+        source="github",
         model="mgd",
         pretrained=True,
-    )
-    return tokenizer, text_encoder, vae, scheduler, unet
+        dataset="dresscode",  # Change to "vitonhd" if needed
+    ).to(device)
 
-def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
-    # Preprocess inputs
-    sketch = sketch.resize((512, 384)).convert("RGB")
-    sketch_tensor = torch.tensor([torch.tensor(sketch, dtype=torch.float32).permute(2, 0, 1) / 255.0]).to(device)
-    # Tokenize prompt
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
-    # Initialize pipeline
+    # Initialize the pipeline
     pipeline = MGDPipe(
-        unet=unet.to(device),
+        vae=vae,
+        text_encoder=text_encoder,
         tokenizer=tokenizer,
+        unet=unet,
         scheduler=scheduler,
     )
 
-    pipeline.enable_attention_slicing()
-    with torch.inference_mode():
-        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)
-    return outputs[0]
+    return pipeline
+
+
+# Function to preprocess and generate images
+def generate_image(pipeline, sketch, prompt, device):
+    """
+    Generate an image using the MGDPipe.
+    Args:
+        pipeline (MGDPipe): Initialized MGDPipe object.
+        sketch (PIL.Image.Image): Sketch uploaded by the user.
+        prompt (str): Text prompt provided by the user.
+        device (torch.device): Device for inference.
+
+    Returns:
+        PIL.Image.Image: Generated image.
+    """
+    # Preprocess the sketch: HWC uint8 in [0, 255] -> 1x3xHxW float batch in [0, 1]
+    sketch = sketch.resize((512, 384)).convert("RGB")
+    sketch_tensor = torch.from_numpy(np.asarray(sketch, dtype=np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(device)
+
+    # Run the pipeline
+    output = pipeline(
+        prompt=prompt,
+        image=torch.zeros_like(sketch_tensor),  # Placeholder for masked image
+        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
+        pose_map=torch.zeros((1, 3, 64, 48)).to(device),  # Placeholder pose map
+        sketch=sketch_tensor,
+        guidance_scale=7.5,
+        num_inference_steps=50,
+    )
+
+    return output.images[0]
+
+
+# Streamlit Interface
 st.title("Garment Designer")
 st.write("Upload a sketch and provide a text description to generate garment designs!")
 
-# User
+# User inputs
 uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
 text_prompt = st.text_input("Enter a text description for the garment")
 
-# Generate button
 if st.button("Generate"):
     if uploaded_file and text_prompt:
+        st.write("Loading models...")
+
+        # Load the pipeline
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        pretrained_model_path = "
+        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required
+        pipeline = load_models(pretrained_model_path, device)
+
         # Load sketch
         sketch = Image.open(uploaded_file)
-        # Generate image
+
+        # Generate the image
         st.write("Generating the garment design...")
-        # Display
-        st.image(
+        generated_image = generate_image(pipeline, sketch, text_prompt, device)
+
+        # Display the result
+        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
     else:
         st.error("Please upload a sketch and enter a text description.")
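Note that Streamlit re-executes app.py from top to bottom on every widget interaction, so as written the CLIP, VAE, and UNet weights are re-downloaded and re-initialized each time the Generate button is pressed. A possible refinement, not part of this commit, is to wrap model loading in st.cache_resource (available in Streamlit 1.18+) so the pipeline is built once per process and reused across reruns. The sketch below assumes the load_models function and the model path from the file above; get_pipeline is a hypothetical helper name:

    import torch
    import streamlit as st

    @st.cache_resource
    def get_pipeline(pretrained_model_path):
        # Built once per server process; later reruns reuse the cached object.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return load_models(pretrained_model_path, device), device

    # Inside the button handler:
    # pipeline, device = get_pipeline("runwayml/stable-diffusion-inpainting")

Resource caching fits here better than data caching because the pipeline holds GPU weights that should be shared, not copied or serialized.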
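The sketch preprocessing in generate_image goes through NumPy to build a 1x3x384x512 batch. If torchvision is installed (it commonly ships alongside torch), an equivalent and more compact route is torchvision's to_tensor helper, which converts a PIL image directly to a CxHxW float tensor scaled to [0, 1]; sketch.png is a hypothetical input file used only for illustration:

    from PIL import Image
    from torchvision.transforms.functional import to_tensor

    # Hypothetical input file standing in for the uploaded sketch
    sketch = Image.open("sketch.png").resize((512, 384)).convert("RGB")
    sketch_tensor = to_tensor(sketch).unsqueeze(0)  # shape (1, 3, 384, 512), float32 in [0, 1]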