Mairaaa committed on
Commit
c5a0203
·
verified ·
1 Parent(s): 4198ed7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -49
app.py CHANGED
@@ -1,77 +1,106 @@
1
- import os
2
- import torch
3
  import streamlit as st
 
4
  from PIL import Image
5
  from transformers import CLIPTextModel, CLIPTokenizer
6
  from diffusers import AutoencoderKL, DDIMScheduler
7
- from src.mgd_pipelines.mgd_pipe import MGDPipe
8
- from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
9
-
10
- # Function to load models
11
- def load_models(pretrained_model_name_or_path, device):
12
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
13
- text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
14
- vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
15
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
16
- scheduler.set_timesteps(50, device=device)
17
-
 
 
 
 
 
 
 
 
 
 
 
18
  unet = torch.hub.load(
19
  repo_or_dir="aimagelab/multimodal-garment-designer",
 
20
  model="mgd",
21
  pretrained=True,
22
- source="github",
23
- )
24
- return tokenizer, text_encoder, vae, scheduler, unet
25
 
26
- # Function to generate images
27
- def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
28
- # Preprocess inputs
29
- sketch = sketch.resize((512, 384)).convert("RGB")
30
- sketch_tensor = torch.tensor([torch.tensor(sketch, dtype=torch.float32).permute(2, 0, 1) / 255.0]).to(device)
31
-
32
- # Tokenize prompt
33
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
34
-
35
- # Initialize pipeline
36
  pipeline = MGDPipe(
37
- text_encoder=text_encoder.to(device),
38
- vae=vae.to(device),
39
- unet=unet.to(device),
40
  tokenizer=tokenizer,
 
41
  scheduler=scheduler,
42
- ).to(device)
43
 
44
- # Generate image
45
- pipeline.enable_attention_slicing()
46
- with torch.inference_mode():
47
- outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)
48
-
49
- return outputs[0]
50
 
51
- # Streamlit UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  st.title("Garment Designer")
53
  st.write("Upload a sketch and provide a text description to generate garment designs!")
54
 
55
- # User Inputs
56
  uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
57
  text_prompt = st.text_input("Enter a text description for the garment")
58
 
59
- # Generate button
60
  if st.button("Generate"):
61
  if uploaded_file and text_prompt:
62
- # Load models
 
 
63
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
- pretrained_model_path = "your-pretrained-model-path" # Replace with actual model path
65
- tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)
66
-
67
  # Load sketch
68
  sketch = Image.open(uploaded_file)
69
-
70
- # Generate image
71
  st.write("Generating the garment design...")
72
- output_image = generate_image(sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device)
73
-
74
- # Display output
75
- st.image(output_image, caption="Generated Garment Design", use_column_width=True)
76
  else:
77
  st.error("Please upload a sketch and enter a text description.")
 
 
 
1
  import streamlit as st
2
+ import torch
3
  from PIL import Image
4
  from transformers import CLIPTextModel, CLIPTokenizer
5
  from diffusers import AutoencoderKL, DDIMScheduler
6
+ from src.mgd_pipelines.mgd_pipe import MGDPipe # Use your implementation of MGDPipe
7
+
8
+
9
+ # Load models and pipeline
10
def load_models(pretrained_model_path, device, dataset="dresscode"):
    """
    Build an MGDPipe from Stable Diffusion components and the MGD UNet.

    The tokenizer, text encoder, VAE and scheduler are read from the
    ``tokenizer``, ``text_encoder``, ``vae`` and ``scheduler`` subfolders of
    ``pretrained_model_path``. The UNet is fetched through ``torch.hub`` from
    the ``aimagelab/multimodal-garment-designer`` GitHub repository.

    Args:
        pretrained_model_path (str): Path or Hugging Face identifier for the model.
        device (torch.device): Device to load the models on.
        dataset (str): Name of the MGD checkpoint requested from the hub entry
            point, ``"dresscode"`` (default) or ``"vitonhd"``.

    Returns:
        MGDPipe: Initialized MGDPipe object.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Validate first so a typo fails immediately instead of after the
    # Stable Diffusion components have already been downloaded.
    supported_datasets = ("dresscode", "vitonhd")
    if dataset not in supported_datasets:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {supported_datasets}."
        )

    # Load components of Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_path, subfolder="text_encoder"
    ).to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)

    # Load the UNet model
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset=dataset,
    ).to(device)

    # Initialize the pipeline
    pipeline = MGDPipe(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )

    return pipeline
 
 
 
 
 
46
 
47
+
48
+ # Function to preprocess and generate images
49
def generate_image(pipeline, sketch, prompt, device):
    """
    Generate an image using the MGDPipe.

    The sketch is converted to RGB, resized to 384x512 (width x height) and
    turned into a ``(1, 3, 512, 384)`` float tensor scaled to ``[0, 1]``. The
    masked image, mask and pose map are all-zero / all-one placeholders, so
    only the sketch and the prompt carry real conditioning.

    Args:
        pipeline (MGDPipe): Initialized MGDPipe object.
        sketch (PIL.Image.Image): Sketch uploaded by the user.
        prompt (str): Text prompt provided by the user.
        device (torch.device): Device for inference.

    Returns:
        PIL.Image.Image: Generated image (first entry of ``output.images``).
    """
    # Local import: this module does not import numpy at file level.
    import numpy as np

    # The pose-map placeholder below is 64x48 (H x W), i.e. one eighth of a
    # 512x384 (H x W) image, so the sketch has to be 512 high and 384 wide.
    target_height, target_width = 512, 384

    # Preprocess the sketch. Convert before resizing so palette/alpha inputs
    # are interpolated as RGB. PIL's resize takes (width, height).
    sketch = sketch.convert("RGB").resize((target_width, target_height))

    # torch.tensor() cannot build a batch from a list holding a multi-element
    # tensor, so go through numpy: (H, W, 3) -> (1, 3, H, W) in [0, 1].
    pixels = np.array(sketch, dtype=np.float32)
    sketch_tensor = (
        torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).div(255.0).to(device)
    )

    # NOTE(review): the placeholders keep the original channel counts
    # (3-channel mask, pose map and sketch). Confirm these match what the
    # loaded MGD UNet expects on its input.
    # NOTE(review): height/width assume MGDPipe.__call__ accepts them like the
    # upstream implementation; without them the pipeline may fall back to its
    # own default resolution - confirm against the MGDPipe in use.
    output = pipeline(
        prompt=prompt,
        image=torch.zeros_like(sketch_tensor),  # Placeholder for masked image
        mask_image=torch.ones_like(sketch_tensor),  # Placeholder for mask
        pose_map=torch.zeros(
            (1, 3, target_height // 8, target_width // 8), device=device
        ),  # Placeholder pose map
        sketch=sketch_tensor,
        height=target_height,
        width=target_width,
        guidance_scale=7.5,
        num_inference_steps=50,
    )

    return output.images[0]
77
+
78
+
79
# Streamlit Interface
#
# Streamlit re-runs this script from the top on every widget interaction, so
# the pipeline is built through a cached helper instead of on every click.
@st.cache_resource(show_spinner=False)
def _cached_pipeline(pretrained_model_path, device_name):
    """Build the MGD pipeline once per (model path, device name) pair.

    The device is passed as a plain string so Streamlit can hash the cache key
    without special handling for ``torch.device``.
    """
    return load_models(pretrained_model_path, torch.device(device_name))


st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

if st.button("Generate"):
    # A whitespace-only prompt is treated the same as an empty one.
    if uploaded_file and text_prompt.strip():
        st.write("Loading models...")

        # Load the pipeline (served from the cache after the first click)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): confirm this repo id still resolves on the Hugging Face Hub.
        pretrained_model_path = "runwayml/stable-diffusion-inpainting"  # Change as required
        pipeline = _cached_pipeline(pretrained_model_path, str(device))

        # Load sketch
        sketch = Image.open(uploaded_file)

        # Generate the image
        st.write("Generating the garment design...")
        generated_image = generate_image(pipeline, sketch, text_prompt, device)

        # Display the result
        st.image(generated_image, caption="Generated Garment Design", use_column_width=True)
    else:
        st.error("Please upload a sketch and enter a text description.")