Mairaaa committed
Commit 40cbb76 · verified · 1 Parent(s): 5fba144

Update app.py

Files changed (1):
  1. app.py +51 -81
app.py CHANGED
@@ -1,97 +1,67 @@
-import os
-import numpy as np # Corrected import
-import torch
 import streamlit as st
+import torch
 from PIL import Image
-from accelerate import Accelerator
+from io import BytesIO
 from diffusers import DDIMScheduler, AutoencoderKL
 from transformers import CLIPTextModel, CLIPTokenizer
-from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
-from src.utils.set_seeds import set_seed
-from src.utils.image_from_pipe import generate_images_from_mgd_pipe
-from src.datasets.dresscode import DressCodeDataset
-
-# Set environment variables
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-os.environ["WANDB_START_METHOD"] = "thread"
+from src.mgd_pipelines.mgd_pipe import MGDPipe
 
-# Function to process inputs and run inference
-def run_inference(prompt, sketch_image=None, category="dresses", seed=1, mixed_precision="fp16"):
-    accelerator = Accelerator(mixed_precision=mixed_precision)
-    device = accelerator.device
-
-    # Load models and datasets
+# Initialize the model and other components
+@st.cache_resource
+def load_model():
+    # Define your model loading logic
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
     tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
     text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
-    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
-    val_scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
-
-    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
-
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    if seed is not None:
-        set_seed(seed)
+    unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
+    scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
 
-    category = [category]
-    test_dataset = DressCodeDataset(
-        dataroot_path="assets\data\dresscode", # Replace with actual dataset path
-        phase="test",
-        category=category,
-        size=(512, 384),
-    )
-
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
-
-    text_encoder.to(device)
-    vae.to(device)
-    unet.to(device).eval()
-
-    if sketch_image is not None:
-        sketch_tensor = (
-            torch.tensor(np.array(sketch_image)).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0
-        )
-
-    val_pipe = MGDPipeDisentangled(
+    pipe = MGDPipe(
         text_encoder=text_encoder,
         vae=vae,
-        unet=unet,
+        unet=unet.to(vae.dtype),
         tokenizer=tokenizer,
-        scheduler=val_scheduler,
+        scheduler=scheduler,
     ).to(device)
+    return pipe
 
-    val_pipe.enable_attention_slicing()
+pipe = load_model()
 
-    generated_images = generate_images_from_mgd_pipe(
-        test_dataloader=test_dataloader,
-        pipe=val_pipe,
-        guidance_scale=7.5,
-        seed=seed,
-        sketch_image=sketch_tensor if sketch_image is not None else None,
-        prompt=prompt,
-    )
-
-    return Image.fromarray((generated_images[0] * 255).astype("uint8"))
+def generate_images(pipe, text_input=None, sketch=None):
+    # Generate images from text or sketch or both
+    images = []
+    if text_input:
+        prompt = [text_input]
+        images.extend(pipe(prompt=prompt))
+    if sketch:
+        sketch_image = Image.open(sketch).convert("RGB")
+        images.extend(pipe(sketch=sketch_image))
+    return images
 
 # Streamlit UI
-st.title("Fashion Image Generator")
-st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")
-
-uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])
-prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
-category = st.text_input("Enter category (optional):", "dresses")
-seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=1)
-precision = st.selectbox("Select precision:", ["fp16", "fp32"])
-
-if uploaded_sketch is not None:
-    sketch_image = Image.open(uploaded_sketch)
-    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)
-
-if st.button("Generate Image"):
-    with st.spinner("Generating image..."):
-        try:
-            result_image = run_inference(prompt, sketch_image, category, seed, precision)
-            st.image(result_image, caption="Generated Image", use_column_width=True)
-        except Exception as e:
-            st.error(f"An error occurred: {e}")
+st.title("Sketch & Text-based Image Generation")
+st.write("Generate images based on rough sketches, text input, or both.")
+
+option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
+
+if option in ["Sketch", "Both"]:
+    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
+
+if option in ["Text", "Both"]:
+    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
+
+if st.button("Generate"):
+    if option == "Sketch" and not sketch_file:
+        st.error("Please upload a sketch.")
+    elif option == "Text" and not text_input:
+        st.error("Please provide text input.")
+    else:
+        # Generate images based on user input
+        with st.spinner("Generating images..."):
+            sketches = BytesIO(sketch_file.read()) if sketch_file else None
+            images = generate_images(pipe, text_input=text_input, sketch=sketches)
+
+            # Display results
+            for i, img in enumerate(images):
+                st.image(img, caption=f"Generated Image {i+1}")
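
The reworked generate_images helper assumes that calling MGDPipe with a prompt or a sketch keyword returns an iterable of PIL images. As a rough sanity check outside Streamlit, a hypothetical smoke test (not part of this commit; StubPipe is invented for illustration) could drive the same logic with a stub pipeline so no model weights are needed:

# Hypothetical smoke test -- StubPipe stands in for MGDPipe so nothing is downloaded.
from PIL import Image

class StubPipe:
    def __call__(self, prompt=None, sketch=None):
        # Mimic the assumed MGDPipe behaviour: return one PIL image per call.
        return [Image.new("RGB", (384, 512), "white")]

def generate_images(pipe, text_input=None, sketch=None):
    # Same logic as the helper added in this commit, inlined so the test stands alone.
    images = []
    if text_input:
        images.extend(pipe(prompt=[text_input]))
    if sketch:
        images.extend(pipe(sketch=Image.open(sketch).convert("RGB")))
    return images

out = generate_images(StubPipe(), text_input="a red dress with floral patterns")
assert len(out) == 1 and isinstance(out[0], Image.Image)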