Mairaaa committed on
Commit
4198ed7
·
verified ·
1 Parent(s): ea12b33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -53
app.py CHANGED
@@ -1,62 +1,77 @@
1
- import streamlit as st
2
  import os
 
 
3
  from PIL import Image
4
- from src.eval import main # Import the modified main function from evl.py
5
-
6
- # Title and Description
7
- st.title("Fashion Image Generator")
8
- st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")
9
 
10
- # File Upload Section
11
- uploaded_file = st.file_uploader("Upload your rough sketch (PNG, JPG, JPEG):", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Sidebar for Parameters
14
- st.sidebar.title("Model Configuration")
15
- pretrained_model_path = st.sidebar.text_input("Pretrained Model Path", "runwayml/stable-diffusion-inpainting")
16
- dataset_path = st.sidebar.text_input("Dataset Path", "./datasets/dresscode")
17
- output_dir = st.sidebar.text_input("Output Directory", "./outputs")
18
- guidance_scale_sketch = st.sidebar.slider("Sketch Guidance Scale", 1.0, 10.0, 7.5)
19
- batch_size = st.sidebar.number_input("Batch Size", min_value=1, max_value=16, value=1)
20
- mixed_precision = st.sidebar.selectbox("Mixed Precision Mode", ["fp16", "fp32"], index=0)
21
- seed = st.sidebar.number_input("Random Seed", value=42, step=1)
 
 
 
 
 
 
 
 
22
 
23
- # Run Button
24
- if st.button("Generate Image"):
25
- if uploaded_file:
26
- # Save uploaded sketch locally
27
- os.makedirs("temp_uploads", exist_ok=True)
28
- sketch_path = os.path.join("temp_uploads", uploaded_file.name)
29
- with open(sketch_path, "wb") as f:
30
- f.write(uploaded_file.getbuffer())
31
 
32
- # Prepare arguments for the backend
33
- args = {
34
- "pretrained_model_name_or_path": pretrained_model_path,
35
- "dataset": "dresscode",
36
- "dataset_path": dataset_path,
37
- "output_dir": output_dir,
38
- "guidance_scale": 7.5,
39
- "guidance_scale_sketch": guidance_scale_sketch,
40
- "mixed_precision": mixed_precision,
41
- "batch_size": batch_size,
42
- "seed": seed,
43
- "save_name": "generated_image", # Output file name
44
- }
45
 
46
- # Run the backend model
47
- st.write("Generating image...")
48
- try:
49
- output_path = main(args) # Call your backend main function
50
- st.write("Image generation complete!")
51
 
52
- # Display the generated image
53
- output_image_path = os.path.join(output_dir, "generated_image.png") # Update if needed
54
- if os.path.exists(output_image_path):
55
- output_image = Image.open(output_image_path)
56
- st.image(output_image, caption="Generated Image", use_column_width=True)
57
- else:
58
- st.error("Image generation failed. No output file found.")
59
- except Exception as e:
60
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
61
  else:
62
- st.error("Please upload a sketch before generating an image.")
 
 
1
  import os
2
+ import torch
3
+ import streamlit as st
4
  from PIL import Image
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import AutoencoderKL, DDIMScheduler
7
+ from src.mgd_pipelines.mgd_pipe import MGDPipe
8
+ from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
 
9
 
10
def load_models(pretrained_model_name_or_path, device):
    """Load every component needed by the garment-generation pipeline.

    Args:
        pretrained_model_name_or_path: Hugging Face repo id or local path of
            a diffusers-style checkpoint with ``tokenizer``, ``text_encoder``,
            ``vae`` and ``scheduler`` subfolders.
        device: Target ``torch.device`` used when pre-setting the scheduler
            timesteps.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    repo = pretrained_model_name_or_path

    # The MGD U-Net is fetched from torch.hub, not from the diffusers
    # checkpoint above.
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        pretrained=True,
        source="github",
    )

    tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")

    scheduler = DDIMScheduler.from_pretrained(repo, subfolder="scheduler")
    # Pre-configure a 50-step DDIM schedule on the target device.
    scheduler.set_timesteps(50, device=device)

    return tokenizer, text_encoder, vae, scheduler, unet
25
 
26
# Function to generate images
def generate_image(sketch, prompt, tokenizer, text_encoder, vae, scheduler, unet, device):
    """Generate a garment image from a sketch and a text prompt.

    Args:
        sketch: PIL image of the rough garment sketch.
        prompt: Text description of the desired garment.
        tokenizer, text_encoder, vae, scheduler, unet: Components returned by
            ``load_models``.
        device: ``torch.device`` to run inference on.

    Returns:
        The first generated image from the MGD pipeline output.
    """
    # Preprocess: fixed model resolution, guaranteed 3-channel RGB.
    sketch = sketch.resize((512, 384)).convert("RGB")

    # BUG FIX: the previous code did
    #   torch.tensor([torchch.tensor(sketch, ...)])
    # but torch.tensor() cannot ingest a PIL Image, and wrapping a tensor in
    # a Python list is invalid. Decode the raw RGB bytes instead and build a
    # (1, 3, H, W) float batch scaled to [0, 1].
    width, height = sketch.size
    sketch_tensor = (
        torch.frombuffer(bytearray(sketch.tobytes()), dtype=torch.uint8)
        .reshape(height, width, 3)
        .permute(2, 0, 1)
        .to(dtype=torch.float32)
        .div(255.0)
        .unsqueeze(0)
        .to(device)
    )

    # Tokenize the prompt; BatchEncoding.to() moves all tensors to `device`.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Initialize the pipeline with every module on the target device.
    pipeline = MGDPipe(
        text_encoder=text_encoder.to(device),
        vae=vae.to(device),
        unet=unet.to(device),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)

    # Attention slicing lowers peak memory at a small speed cost.
    pipeline.enable_attention_slicing()
    with torch.inference_mode():
        outputs = pipeline(images=sketch_tensor, text=inputs["input_ids"], guidance_scale=7.5)

    return outputs[0]
 
 
50
 
51
# ----------------------------- Streamlit UI -----------------------------
st.title("Garment Designer")
st.write("Upload a sketch and provide a text description to generate garment designs!")

# User inputs: a sketch image plus a free-text garment description.
uploaded_file = st.file_uploader("Upload your sketch", type=["png", "jpg", "jpeg"])
text_prompt = st.text_input("Enter a text description for the garment")

# Generate button: both inputs are required before we touch the models.
if st.button("Generate"):
    if uploaded_file and text_prompt:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): placeholder — must be replaced with a real diffusers
        # checkpoint id/path (e.g. "runwayml/stable-diffusion-inpainting")
        # before this app can work.
        pretrained_model_path = "your-pretrained-model-path"
        try:
            # Model loading and generation can both raise (bad path, OOM,
            # network failure); surface the error in the UI instead of
            # crashing the script run.
            tokenizer, text_encoder, vae, scheduler, unet = load_models(pretrained_model_path, device)

            # Load sketch from the uploaded buffer.
            sketch = Image.open(uploaded_file)

            st.write("Generating the garment design...")
            output_image = generate_image(sketch, text_prompt, tokenizer, text_encoder, vae, scheduler, unet, device)

            # use_container_width replaces the deprecated (now removed)
            # use_column_width parameter of st.image.
            st.image(output_image, caption="Generated Garment Design", use_container_width=True)
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.error("Please upload a sketch and enter a text description.")