Mairaaa commited on
Commit
8689a3c
·
verified ·
1 Parent(s): d308227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -124
app.py CHANGED
@@ -1,127 +1,62 @@
1
  import streamlit as st
2
- import torch
3
  from PIL import Image
4
- from io import BytesIO
5
- from diffusers import DDIMScheduler, AutoencoderKL
6
- from transformers import CLIPTextModel, CLIPTokenizer
7
- from src.mgd_pipelines.mgd_pipe import MGDPipe
8
-
9
- # Initialize the model and other components
10
- @st.cache_resource
11
- def load_model():
12
- try:
13
- # Define your model loading logic
14
- print("Initializing model loading...")
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- print(f"Device selected: {device}")
17
-
18
- # Load the VAE
19
- print("Loading VAE...")
20
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
21
- print("VAE loaded successfully.")
22
-
23
- # Load the tokenizer
24
- print("Loading tokenizer...")
25
- tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
26
- print("Tokenizer loaded successfully.")
27
-
28
- # Load the text encoder
29
- print("Loading text encoder...")
30
- text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
31
- print("Text encoder loaded successfully.")
32
-
33
- # Load the UNet model
34
- print("Loading UNet...")
35
- unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
36
- print("UNet loaded successfully.")
37
-
38
- # Load the scheduler
39
- print("Loading scheduler...")
40
- scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
41
- print("Scheduler loaded successfully.")
42
-
43
- # Initialize the pipeline
44
- print("Initializing pipeline...")
45
- pipe = MGDPipe(
46
- text_encoder=text_encoder,
47
- vae=vae,
48
- unet=unet.to(vae.dtype),
49
- tokenizer=tokenizer,
50
- scheduler=scheduler,
51
- ).to(device)
52
- pipe.enable_attention_slicing()
53
- print("Pipeline initialized successfully.")
54
- return pipe
55
- except Exception as e:
56
- print(f"Error loading the model: {e}")
57
- return None
58
-
59
- pipe = load_model()
60
-
61
- def generate_images(pipe, text_input=None, sketch=None):
62
- # Generate images from text or sketch or both
63
- images = []
64
- try:
65
- if pipe:
66
- # Generate from text
67
- if text_input:
68
- print(f"Generating image from text: {text_input}")
69
- images.append(pipe(prompt=[text_input]))
70
-
71
- # Generate from sketch
72
- if sketch:
73
- print("Generating image from sketch.")
74
- sketch_image = Image.open(sketch).convert("RGB")
75
- images.append(pipe(sketch=sketch_image))
76
- except Exception as e:
77
- print(f"Error during image generation: {e}")
78
- return images
79
-
80
- # Streamlit UI
81
- st.title("Sketch & Text-based Image Generation")
82
- st.write("Generate images based on rough sketches, text input, or both.")
83
-
84
- # Input options
85
- option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
86
-
87
- sketch_file = None
88
- text_input = None
89
-
90
- # Get sketch input
91
- if option in ["Sketch", "Both"]:
92
- sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
93
-
94
- # Get text input
95
- if option in ["Text", "Both"]:
96
- text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
97
-
98
- # Generate button
99
- if st.button("Generate"):
100
- # Ensure the model is loaded
101
- if pipe is None:
102
- st.error("Model failed to load. Please restart the application.")
103
- st.stop()
104
-
105
- # Validate inputs
106
- sketches = BytesIO(sketch_file.read()) if sketch_file else None
107
-
108
- if option == "Sketch" and not sketch_file:
109
- st.error("Please upload a sketch.")
110
- elif option == "Text" and not text_input:
111
- st.error("Please provide text input.")
112
- elif option == "Both" and not (sketch_file or text_input):
113
- st.error("Please provide both a sketch and a text prompt.")
114
  else:
115
- # Generate images
116
- with st.spinner("Generating images..."):
117
- images = generate_images(pipe, text_input=text_input, sketch=sketches)
118
-
119
- # Display results
120
- if images:
121
- for i, img in enumerate(images):
122
- if isinstance(img, torch.Tensor): # Convert tensor to image
123
- img = img.squeeze().permute(1, 2, 0).cpu().numpy()
124
- img = Image.fromarray((img * 255).astype("uint8"))
125
- st.image(img, caption=f"Generated Image {i+1}")
126
- else:
127
- st.error("Failed to generate images. Please check the inputs or model configuration.")
 
1
  import streamlit as st
2
+ import os
3
  from PIL import Image
4
+ from evl import main # Import the modified main function from evl.py
5
+
6
+ # Title and Description
7
+ st.title("Fashion Image Generator")
8
+ st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")
9
+
10
+ # File Upload Section
11
+ uploaded_file = st.file_uploader("Upload your rough sketch (PNG, JPG, JPEG):", type=["png", "jpg", "jpeg"])
12
+
13
+ # Sidebar for Parameters
14
+ st.sidebar.title("Model Configuration")
15
+ pretrained_model_path = st.sidebar.text_input("Pretrained Model Path", "./models")
16
+ dataset_path = st.sidebar.text_input("Dataset Path", "./datasets/dresscode")
17
+ output_dir = st.sidebar.text_input("Output Directory", "./outputs")
18
+ guidance_scale_sketch = st.sidebar.slider("Sketch Guidance Scale", 1.0, 10.0, 7.5)
19
+ batch_size = st.sidebar.number_input("Batch Size", min_value=1, max_value=16, value=1)
20
+ mixed_precision = st.sidebar.selectbox("Mixed Precision Mode", ["fp16", "fp32"], index=0)
21
+ seed = st.sidebar.number_input("Random Seed", value=42, step=1)
22
+
23
+ # Run Button
24
+ if st.button("Generate Image"):
25
+ if uploaded_file:
26
+ # Save uploaded sketch locally
27
+ os.makedirs("temp_uploads", exist_ok=True)
28
+ sketch_path = os.path.join("temp_uploads", uploaded_file.name)
29
+ with open(sketch_path, "wb") as f:
30
+ f.write(uploaded_file.getbuffer())
31
+
32
+ # Prepare arguments for the backend
33
+ args = {
34
+ "pretrained_model_name_or_path": pretrained_model_path,
35
+ "dataset": "dresscode",
36
+ "dataset_path": dataset_path,
37
+ "output_dir": output_dir,
38
+ "guidance_scale": 7.5,
39
+ "guidance_scale_sketch": guidance_scale_sketch,
40
+ "mixed_precision": mixed_precision,
41
+ "batch_size": batch_size,
42
+ "seed": seed,
43
+ "save_name": "generated_image", # Output file name
44
+ }
45
+
46
+ # Run the backend model
47
+ st.write("Generating image...")
48
+ try:
49
+ output_path = main(args) # Call your backend main function
50
+ st.write("Image generation complete!")
51
+
52
+ # Display the generated image
53
+ output_image_path = os.path.join(output_dir, "generated_image.png") # Update if needed
54
+ if os.path.exists(output_image_path):
55
+ output_image = Image.open(output_image_path)
56
+ st.image(output_image, caption="Generated Image", use_column_width=True)
57
+ else:
58
+ st.error("Image generation failed. No output file found.")
59
+ except Exception as e:
60
+ st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  else:
62
+ st.error("Please upload a sketch before generating an image.")