Mairaaa committed on
Commit
5fba144
·
verified ·
1 Parent(s): e22ce03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
- import pandas as np
3
  import torch
4
  import streamlit as st
5
  from PIL import Image
6
  from accelerate import Accelerator
7
  from diffusers import DDIMScheduler, AutoencoderKL
8
  from transformers import CLIPTextModel, CLIPTokenizer
9
- from src.mgd_pipelines.mgd_pipe import MGDPipe
10
  from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
11
  from src.utils.set_seeds import set_seed
12
  from src.utils.image_from_pipe import generate_images_from_mgd_pipe
@@ -17,8 +16,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
17
  os.environ["WANDB_START_METHOD"] = "thread"
18
 
19
  # Function to process inputs and run inference
20
- def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
21
- # Initialize accelerator
22
  accelerator = Accelerator(mixed_precision=mixed_precision)
23
  device = accelerator.device
24
 
@@ -26,39 +24,35 @@ def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixe
26
  tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
27
  text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
28
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
29
- val_scheduler = DDIMScheduler.from_pretrained("ptx0/pseudo-journey-v2", subfolder="scheduler")
30
 
31
- # Load UNet (assumed pretrained)
32
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
33
 
34
- # Freeze VAE and text encoder
35
  vae.requires_grad_(False)
36
  text_encoder.requires_grad_(False)
37
 
38
- # Set seed for reproducibility
39
  if seed is not None:
40
  set_seed(seed)
41
 
42
- # Load appropriate dataset
43
  category = [category]
44
  test_dataset = DressCodeDataset(
45
- dataroot_path="path_to_dataset", phase="test", category=category, size=(512, 384)
 
 
 
46
  )
47
 
48
  test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
49
 
50
- # Move models to the device
51
  text_encoder.to(device)
52
  vae.to(device)
53
  unet.to(device).eval()
54
 
55
- # Handle sketch and text inputs
56
  if sketch_image is not None:
57
- # Process the sketch (resize, normalize, etc.)
58
- sketch_image = sketch_image.resize((512, 384))
59
- sketch_tensor = torch.tensor(np.array(sketch_image)).unsqueeze(0).float().to(device)
60
 
61
- # Select pipeline (disentangled if required)
62
  val_pipe = MGDPipeDisentangled(
63
  text_encoder=text_encoder,
64
  vae=vae,
@@ -69,41 +63,35 @@ def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixe
69
 
70
  val_pipe.enable_attention_slicing()
71
 
72
- # Generate image
73
  generated_images = generate_images_from_mgd_pipe(
74
  test_dataloader=test_dataloader,
75
  pipe=val_pipe,
76
  guidance_scale=7.5,
77
  seed=seed,
78
  sketch_image=sketch_tensor if sketch_image is not None else None,
79
- prompt=prompt
80
  )
81
 
82
- return generated_images[0] # Assuming single image output
83
 
84
  # Streamlit UI
85
  st.title("Fashion Image Generator")
86
  st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")
87
 
88
- # Upload a sketch image
89
  uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])
90
-
91
- # Text input for prompt
92
  prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
93
-
94
- # Input options
95
  category = st.text_input("Enter category (optional):", "dresses")
96
- seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=None)
97
  precision = st.selectbox("Select precision:", ["fp16", "fp32"])
98
 
99
- # Show uploaded sketch image
100
  if uploaded_sketch is not None:
101
  sketch_image = Image.open(uploaded_sketch)
102
  st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)
103
 
104
- # Button to generate image
105
  if st.button("Generate Image"):
106
  with st.spinner("Generating image..."):
107
- # Run inference with sketch or prompt (or both)
108
- result_image = run_inference(prompt, sketch_image, category, seed, precision)
109
- st.image(result_image, caption="Generated Image", use_column_width=True)
 
 
 
1
  import os
2
+ import numpy as np # Corrected import
3
  import torch
4
  import streamlit as st
5
  from PIL import Image
6
  from accelerate import Accelerator
7
  from diffusers import DDIMScheduler, AutoencoderKL
8
  from transformers import CLIPTextModel, CLIPTokenizer
 
9
  from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
10
  from src.utils.set_seeds import set_seed
11
  from src.utils.image_from_pipe import generate_images_from_mgd_pipe
 
16
  os.environ["WANDB_START_METHOD"] = "thread"
17
 
18
  # Function to process inputs and run inference
19
+ def run_inference(prompt, sketch_image=None, category="dresses", seed=1, mixed_precision="fp16"):
 
20
  accelerator = Accelerator(mixed_precision=mixed_precision)
21
  device = accelerator.device
22
 
 
24
  tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
25
  text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
26
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
27
+ val_scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
28
 
 
29
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
30
 
 
31
  vae.requires_grad_(False)
32
  text_encoder.requires_grad_(False)
33
 
 
34
  if seed is not None:
35
  set_seed(seed)
36
 
 
37
  category = [category]
38
  test_dataset = DressCodeDataset(
39
+ dataroot_path="assets\data\dresscode", # Replace with actual dataset path
40
+ phase="test",
41
+ category=category,
42
+ size=(512, 384),
43
  )
44
 
45
  test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
46
 
 
47
  text_encoder.to(device)
48
  vae.to(device)
49
  unet.to(device).eval()
50
 
 
51
  if sketch_image is not None:
52
+ sketch_tensor = (
53
+ torch.tensor(np.array(sketch_image)).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0
54
+ )
55
 
 
56
  val_pipe = MGDPipeDisentangled(
57
  text_encoder=text_encoder,
58
  vae=vae,
 
63
 
64
  val_pipe.enable_attention_slicing()
65
 
 
66
  generated_images = generate_images_from_mgd_pipe(
67
  test_dataloader=test_dataloader,
68
  pipe=val_pipe,
69
  guidance_scale=7.5,
70
  seed=seed,
71
  sketch_image=sketch_tensor if sketch_image is not None else None,
72
+ prompt=prompt,
73
  )
74
 
75
+ return Image.fromarray((generated_images[0] * 255).astype("uint8"))
76
 
77
  # Streamlit UI
78
  st.title("Fashion Image Generator")
79
  st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")
80
 
 
81
  uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])
 
 
82
  prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
 
 
83
  category = st.text_input("Enter category (optional):", "dresses")
84
+ seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=1)
85
  precision = st.selectbox("Select precision:", ["fp16", "fp32"])
86
 
 
87
  if uploaded_sketch is not None:
88
  sketch_image = Image.open(uploaded_sketch)
89
  st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)
90
 
 
91
  if st.button("Generate Image"):
92
  with st.spinner("Generating image..."):
93
+ try:
94
+ result_image = run_inference(prompt, sketch_image, category, seed, precision)
95
+ st.image(result_image, caption="Generated Image", use_column_width=True)
96
+ except Exception as e:
97
+ st.error(f"An error occurred: {e}")