Mairaaa committed on
Commit 74d4e67 · verified · 1 Parent(s): 1e76609

Update app.py

Files changed (1):
  1. app.py +34 -27
app.py CHANGED
@@ -9,37 +9,41 @@ from src.mgd_pipelines.mgd_pipe import MGDPipe
 # Initialize the model and other components
 @st.cache_resource
 def load_model():
-    # Define your model loading logic
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
-    print("VAE model loaded successfully.")
-    except OSError as e:
-        print(f"Error loading the model: {e}")
-    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
-    unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
-    scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
+    try:
+        # Define your model loading logic
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+        print("VAE model loaded successfully.")
+
+        tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
+        text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
+        unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
+        scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
 
-    pipe = MGDPipe(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet.to(vae.dtype),
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(device)
-    return pipe
+        pipe = MGDPipe(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet.to(vae.dtype),
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+        ).to(device)
+        return pipe
+    except OSError as e:
+        print(f"Error loading the model: {e}")
+        return None
 
 pipe = load_model()
 
 def generate_images(pipe, text_input=None, sketch=None):
     # Generate images from text or sketch or both
     images = []
-    if text_input:
-        prompt = [text_input]
-        images.extend(pipe(prompt=prompt))
-    if sketch:
-        sketch_image = Image.open(sketch).convert("RGB")
-        images.extend(pipe(sketch=sketch_image))
+    if pipe:
+        if text_input:
+            prompt = [text_input]
+            images.extend(pipe(prompt=prompt))
+        if sketch:
+            sketch_image = Image.open(sketch).convert("RGB")
+            images.extend(pipe(sketch=sketch_image))
     return images
 
 # Streamlit UI
@@ -65,6 +69,9 @@ if st.button("Generate"):
     sketches = BytesIO(sketch_file.read()) if sketch_file else None
     images = generate_images(pipe, text_input=text_input, sketch=sketches)
 
-    # Display results
-    for i, img in enumerate(images):
-        st.image(img, caption=f"Generated Image {i+1}")
+    if images:
+        # Display results
+        for i, img in enumerate(images):
+            st.image(img, caption=f"Generated Image {i+1}")
+    else:
+        st.error("Failed to generate images. Please check the model or inputs.")