Mairaaa commited on
Commit
f26e6e9
·
verified ·
1 Parent(s): e27303d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -12
app.py CHANGED
@@ -20,6 +20,7 @@ def load_model():
20
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
21
  scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
22
 
 
23
  pipe = MGDPipe(
24
  text_encoder=text_encoder,
25
  vae=vae,
@@ -27,8 +28,9 @@ def load_model():
27
  tokenizer=tokenizer,
28
  scheduler=scheduler,
29
  ).to(device)
 
30
  return pipe
31
- except OSError as e:
32
  print(f"Error loading the model: {e}")
33
  return None
34
 
@@ -37,32 +39,48 @@ pipe = load_model()
37
  def generate_images(pipe, text_input=None, sketch=None):
38
  # Generate images from text or sketch or both
39
  images = []
40
- if pipe:
41
- if text_input:
42
- prompt = [text_input]
43
- images.extend(pipe(prompt=prompt))
44
- if sketch:
45
- sketch_image = Image.open(sketch).convert("RGB")
46
- images.extend(pipe(sketch=sketch_image))
 
 
 
 
 
 
 
47
  return images
48
 
49
  # Streamlit UI
50
  st.title("Sketch & Text-based Image Generation")
51
  st.write("Generate images based on rough sketches, text input, or both.")
52
 
 
53
  option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
54
 
55
  sketch_file = None
56
  text_input = None
57
 
 
58
  if option in ["Sketch", "Both"]:
59
  sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
60
 
 
61
  if option in ["Text", "Both"]:
62
  text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
63
 
 
64
  if st.button("Generate"):
65
- # Ensure text_input and sketches are handled properly
 
 
 
 
 
66
  sketches = BytesIO(sketch_file.read()) if sketch_file else None
67
 
68
  if option == "Sketch" and not sketch_file:
@@ -72,13 +90,16 @@ if st.button("Generate"):
72
  elif option == "Both" and not (sketch_file or text_input):
73
  st.error("Please provide both a sketch and a text prompt.")
74
  else:
75
- # Generate images based on user input
76
  with st.spinner("Generating images..."):
77
  images = generate_images(pipe, text_input=text_input, sketch=sketches)
78
 
 
79
  if images:
80
- # Display results
81
  for i, img in enumerate(images):
 
 
 
82
  st.image(img, caption=f"Generated Image {i+1}")
83
  else:
84
- st.error("Failed to generate images. Please check the model or inputs.")
 
20
  unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
21
  scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")
22
 
23
+ # Initialize the pipeline
24
  pipe = MGDPipe(
25
  text_encoder=text_encoder,
26
  vae=vae,
 
28
  tokenizer=tokenizer,
29
  scheduler=scheduler,
30
  ).to(device)
31
+ pipe.enable_attention_slicing() # Enable memory-efficient inference
32
  return pipe
33
+ except Exception as e:
34
  print(f"Error loading the model: {e}")
35
  return None
36
 
 
39
  def generate_images(pipe, text_input=None, sketch=None):
40
  # Generate images from text or sketch or both
41
  images = []
42
+ try:
43
+ if pipe:
44
+ # Generate from text
45
+ if text_input:
46
+ print(f"Generating image from text: {text_input}")
47
+ images.append(pipe(prompt=[text_input]))
48
+
49
+ # Generate from sketch
50
+ if sketch:
51
+ print("Generating image from sketch.")
52
+ sketch_image = Image.open(sketch).convert("RGB")
53
+ images.append(pipe(sketch=sketch_image))
54
+ except Exception as e:
55
+ print(f"Error during image generation: {e}")
56
  return images
57
 
58
  # Streamlit UI
59
  st.title("Sketch & Text-based Image Generation")
60
  st.write("Generate images based on rough sketches, text input, or both.")
61
 
62
+ # Input options
63
  option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
64
 
65
  sketch_file = None
66
  text_input = None
67
 
68
+ # Get sketch input
69
  if option in ["Sketch", "Both"]:
70
  sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
71
 
72
+ # Get text input
73
  if option in ["Text", "Both"]:
74
  text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")
75
 
76
+ # Generate button
77
  if st.button("Generate"):
78
+ # Ensure the model is loaded
79
+ if pipe is None:
80
+ st.error("Model failed to load. Please restart the application.")
81
+ st.stop()
82
+
83
+ # Validate inputs
84
  sketches = BytesIO(sketch_file.read()) if sketch_file else None
85
 
86
  if option == "Sketch" and not sketch_file:
 
90
  elif option == "Both" and not (sketch_file or text_input):
91
  st.error("Please provide both a sketch and a text prompt.")
92
  else:
93
+ # Generate images
94
  with st.spinner("Generating images..."):
95
  images = generate_images(pipe, text_input=text_input, sketch=sketches)
96
 
97
+ # Display results
98
  if images:
 
99
  for i, img in enumerate(images):
100
+ if isinstance(img, torch.Tensor): # Convert tensor to image
101
+ img = img.squeeze().permute(1, 2, 0).cpu().numpy()
102
+ img = Image.fromarray((img * 255).astype("uint8"))
103
  st.image(img, caption=f"Generated Image {i+1}")
104
  else:
105
+ st.error("Failed to generate images. Please check the inputs or model configuration.")