Mairaaa committed on
Commit 831b686 · verified · 1 Parent(s): 0547d6b

Update app.py

Files changed (1)
  app.py +83 -69
app.py CHANGED
@@ -1,95 +1,109 @@
  import os
  import torch
  import streamlit as st
- from diffusers import AutoencoderKL, DDIMScheduler
  from transformers import CLIPTextModel, CLIPTokenizer
  from src.mgd_pipelines.mgd_pipe import MGDPipe
  from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
- from src.utils.image_from_pipe import generate_images_from_mgd_pipe
- from accelerate import Accelerator
- from diffusers.utils import check_min_version
  from src.utils.set_seeds import set_seed

- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
- check_min_version("0.10.0.dev0")
-
- # Set the environment variables for Hugging Face Spaces
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
  os.environ["WANDB_START_METHOD"] = "thread"

- # Streamlit interface components
- st.title("Fashion Image Generation with Multimodal Garment Designer")
-
- # Streamlit Input Parameters
- category = st.selectbox("Select Category", ["dresses", "upper_body", "lower_body", "all"])
- guidance_scale = st.slider("Guidance Scale", min_value=0.1, max_value=20.0, value=7.5, step=0.1)
- guidance_scale_pose = st.slider("Guidance Scale (Pose)", min_value=0.1, max_value=20.0, value=7.5, step=0.1)
- guidance_scale_sketch = st.slider("Guidance Scale (Sketch)", min_value=0.1, max_value=20.0, value=7.5, step=0.1)
- sketch_cond_rate = st.slider("Sketch Conditioning Rate", min_value=0.1, max_value=1.0, value=0.5, step=0.05)
- start_cond_rate = st.slider("Start Conditioning Rate", min_value=0.1, max_value=1.0, value=0.5, step=0.05)
- seed = st.number_input("Seed", value=42, min_value=1)
-
- # Button to run the image generation
- if st.button("Generate Image"):
-     # Initialize Accelerator (for mixed precision, etc.)
-     accelerator = Accelerator()
      device = accelerator.device

-     # Set the seed
-     set_seed(seed)

-     # Model and Tokenizer loading (use pre-trained from Hugging Face)
-     model_name = "stabilityai/stable-diffusion-2-1-base"  # Use appropriate model name

-     # Load scheduler, tokenizer, and models
-     val_scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
-     val_scheduler.set_timesteps(50, device=device)

-     tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-     text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
-     vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")

-     # Load UNet model (you can use your own model)
-     unet = torch.hub.load(
-         dataset="aimagelab/multimodal-garment-designer",
-         repo_or_dir="aimagelab/multimodal-garment-designer",
-         source="github",
-         model="mgd",
-         pretrained=True,
      )

-     # Freeze VAE and text encoder
-     vae.requires_grad_(False)
-     text_encoder.requires_grad_(False)

-     # Select pipeline (use disentangled option if needed)
-     val_pipe = MGDPipe(
          text_encoder=text_encoder,
          vae=vae,
-         unet=unet.to(vae.dtype),
          tokenizer=tokenizer,
          scheduler=val_scheduler,
      ).to(device)

-     # Run image generation using your pipeline
-     with torch.no_grad():
-         # Generate the image
-         images = generate_images_from_mgd_pipe(
-             test_order="test",  # or some predefined order
-             pipe=val_pipe,
-             test_dataloader=None,  # Adjust accordingly, or use pre-existing dataset
-             save_name="generated_image",
-             dataset="dresscode",  # Adjust if needed
-             output_dir=".",  # Save location
-             guidance_scale=guidance_scale,
-             guidance_scale_pose=guidance_scale_pose,
-             guidance_scale_sketch=guidance_scale_sketch,
-             sketch_cond_rate=sketch_cond_rate,
-             start_cond_rate=start_cond_rate,
-             no_pose=False,
-             disentagle=False,  # Adjust if needed
-             seed=seed,
-         )
-
-         # Display the generated image
-         st.image(images[0], caption="Generated Fashion Image", use_column_width=True)

  import os
+ import numpy as np  # np.array() is used below to convert the PIL sketch
  import torch
  import streamlit as st
+ from PIL import Image
+ from accelerate import Accelerator
+ from diffusers import DDIMScheduler, AutoencoderKL
  from transformers import CLIPTextModel, CLIPTokenizer
  from src.mgd_pipelines.mgd_pipe import MGDPipe
  from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
  from src.utils.set_seeds import set_seed
+ from src.utils.image_from_pipe import generate_images_from_mgd_pipe
+ from datasets.dresscode import DressCodeDataset
+
+ # Set environment variables
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
  os.environ["WANDB_START_METHOD"] = "thread"
 
+ # Function to process inputs and run inference
+ def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
+     # Initialize accelerator (Accelerate expects "no" rather than "fp32" for full precision)
+     if mixed_precision == "fp32":
+         mixed_precision = "no"
+     accelerator = Accelerator(mixed_precision=mixed_precision)
      device = accelerator.device
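+     # NOTE: nothing below is wrapped in accelerator.prepare(), so the
+     # mixed_precision flag mostly records intent; the Accelerator is used
+     # here chiefly to pick the device.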
 
+     # Load models and datasets (the subfolder arguments are dropped where
+     # these repos keep the files at the repository root)
+     tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32")
+     text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32")
+     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+     val_scheduler = DDIMScheduler.from_pretrained("ptx0/pseudo-journey-v2", subfolder="scheduler")
+
+     # Load UNet (assumed pretrained)
+     unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
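+     # NOTE: the checkpoint names above are this commit's choices; any compatible
+     # CLIP text encoder / VAE / DDIM scheduler trio (e.g. the
+     # stable-diffusion-2-1-base subfolders used before this change) should slot in here.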
 
+     # Freeze VAE and text encoder
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+
+     # Set seed for reproducibility
+     if seed is not None:
+         set_seed(seed)
 
+     # Load appropriate dataset
+     category = [category]
+     test_dataset = DressCodeDataset(
+         dataroot_path="path_to_dataset", phase="test", category=category, size=(512, 384)
      )

+     test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
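+     # NOTE: "path_to_dataset" is a placeholder and must point at a local copy
+     # of the Dress Code dataset for DressCodeDataset to find anything.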
+
+     # Move models to the device
+     text_encoder.to(device)
+     vae.to(device)
+     unet.to(device).eval()

+     # Handle sketch and text inputs
+     if sketch_image is not None:
+         # Process the sketch: resize and normalize to a (1, C, H, W) float tensor in [0, 1]
+         sketch_image = sketch_image.convert("RGB").resize((512, 384))
+         sketch_tensor = (
+             torch.tensor(np.array(sketch_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+         ).to(device)
+
+     # Select pipeline (disentangled variant; MGDPipe is the standard option)
+     val_pipe = MGDPipeDisentangled(
          text_encoder=text_encoder,
          vae=vae,
+         unet=unet,
          tokenizer=tokenizer,
          scheduler=val_scheduler,
      ).to(device)
 
+     val_pipe.enable_attention_slicing()
+
+     # Generate image
+     generated_images = generate_images_from_mgd_pipe(
+         test_dataloader=test_dataloader,
+         pipe=val_pipe,
+         guidance_scale=7.5,
+         seed=seed,
+         sketch_image=sketch_tensor if sketch_image is not None else None,
+         prompt=prompt,
+     )
+
+     return generated_images[0]  # Assuming single image output
+
+ # Streamlit UI
+ st.title("Fashion Image Generator")
+ st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")
+
+ # Upload a sketch image
+ uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])
+
+ # Text input for prompt
+ prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
+
+ # Input options
+ category = st.text_input("Enter category (optional):", "dresses")
+ seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=42)  # the slider needs a concrete default
+ precision = st.selectbox("Select precision:", ["fp16", "fp32"])
+
+ # Show uploaded sketch image (default to None so generation also works without one)
+ sketch_image = None
+ if uploaded_sketch is not None:
+     sketch_image = Image.open(uploaded_sketch)
+     st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)
+
+ # Button to generate image
+ if st.button("Generate Image"):
+     with st.spinner("Generating image..."):
+         # Run inference with sketch or prompt (or both)
+         result_image = run_inference(prompt, sketch_image, category, seed, precision)
+         st.image(result_image, caption="Generated Image", use_column_width=True)
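
For a quick smoke test outside the Streamlit UI, the new run_inference entry point can also be driven headlessly. The sketch below is a minimal example under stated assumptions: app.py is importable as app (importing it also executes its Streamlit top-level calls, which only warn outside streamlit run), the dataset path hard-coded in run_inference points at a local Dress Code copy, rough_sketch.png is a placeholder input file, and the pipeline is assumed to return PIL images.

    # Minimal headless driver for run_inference (paths are placeholders).
    from PIL import Image

    from app import run_inference

    sketch = Image.open("rough_sketch.png")  # any RGB sketch; optional
    result = run_inference(
        prompt="A red dress with floral patterns",
        sketch_image=sketch,
        category="dresses",
        seed=42,
        mixed_precision="fp16",
    )
    result.save("generated.png")  # assumes a PIL image is returned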