Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,37 +9,41 @@ from src.mgd_pipelines.mgd_pipe import MGDPipe
|
|
9 |
# Initialize the model and other components
|
10 |
@st.cache_resource
|
11 |
def load_model():
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
pipe = load_model()
|
33 |
|
34 |
def generate_images(pipe, text_input=None, sketch=None):
|
35 |
# Generate images from text or sketch or both
|
36 |
images = []
|
37 |
-
if
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
43 |
return images
|
44 |
|
45 |
# Streamlit UI
|
@@ -65,6 +69,9 @@ if st.button("Generate"):
|
|
65 |
sketches = BytesIO(sketch_file.read()) if sketch_file else None
|
66 |
images = generate_images(pipe, text_input=text_input, sketch=sketches)
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
# Initialize the model and other components
@st.cache_resource
def load_model():
    """Load and assemble the multimodal-garment-designer (MGD) pipeline.

    Downloads/loads each component (VAE, CLIP tokenizer + text encoder,
    MGD UNet via torch.hub, DDIM scheduler), assembles them into an
    MGDPipe, and moves it to the best available device.

    Returns:
        MGDPipe on CUDA if available (CPU otherwise), or ``None`` when any
        component fails to load — the callers already check for ``None``
        before generating, so a load failure degrades gracefully.
    """
    try:
        # Define your model loading logic
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        print("VAE model loaded successfully.")

        tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
        unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
        scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")

        pipe = MGDPipe(
            text_encoder=text_encoder,
            vae=vae,
            # Cast the UNet to the VAE's dtype so the pipeline runs in one precision.
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=scheduler,
        ).to(device)
        return pipe
    except (OSError, RuntimeError, ValueError) as e:
        # OSError covers missing/corrupt local files and failed downloads;
        # from_pretrained / torch.hub.load also raise RuntimeError or
        # ValueError for incompatible checkpoints — catching only OSError
        # (as before) would crash the app instead of returning None.
        print(f"Error loading the model: {e}")
        return None
# Build the pipeline once at startup; @st.cache_resource reuses it across
# reruns. May be None if loading failed — downstream code checks for that.
pipe = load_model()
def generate_images(pipe, text_input=None, sketch=None):
    """Generate images from text or sketch or both.

    Args:
        pipe: the loaded generation pipeline, or None if loading failed.
        text_input: optional prompt string.
        sketch: optional file-like object containing a sketch image.

    Returns:
        A list of generated images; empty when the pipeline is unavailable
        or neither input was supplied.
    """
    images = []
    if not pipe:
        # Model never loaded — nothing to generate with.
        return images
    if text_input:
        images.extend(pipe(prompt=[text_input]))
    if sketch:
        sketch_image = Image.open(sketch).convert("RGB")
        images.extend(pipe(sketch=sketch_image))
    return images
48 |
|
49 |
# Streamlit UI
|
|
|
    # (Body of the `if st.button("Generate"):` handler — see hunk header.)
    # Wrap the uploaded sketch's bytes so PIL can open them; None when no file.
    sketches = BytesIO(sketch_file.read()) if sketch_file else None
    images = generate_images(pipe, text_input=text_input, sketch=sketches)

    if images:
        # Display results
        for i, img in enumerate(images):
            st.image(img, caption=f"Generated Image {i+1}")
    else:
        # Either the model failed to load (pipe is None) or no input produced output.
        st.error("Failed to generate images. Please check the model or inputs.")