File size: 2,709 Bytes
6ae61a6
8689a3c
831b686
8cbc242
8689a3c
 
 
 
 
 
 
 
 
 
290f7fe
8689a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40cbb76
8689a3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Streamlit front end for the fashion image generator.

The page lets the user upload a rough sketch, adjust generation parameters in
the sidebar, and invoke the backend ``main`` function from ``src.eval``. After
the backend returns, the generated image is displayed if an output file can be
located on disk.
"""
import os

import streamlit as st
from PIL import Image

from src.eval import main  # Backend entry point defined in src/eval.py

# Base name the backend is asked to save its result under. The same value is
# used for the ``save_name`` argument and for the fallback output lookup, so
# the two cannot drift apart.
SAVE_NAME = "generated_image"

# Title and Description
st.title("Fashion Image Generator")
st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")

# File Upload Section
uploaded_file = st.file_uploader("Upload your rough sketch (PNG, JPG, JPEG):", type=["png", "jpg", "jpeg"])

# Sidebar for Parameters
st.sidebar.title("Model Configuration")
pretrained_model_path = st.sidebar.text_input("Pretrained Model Path", "runwayml/stable-diffusion-inpainting")
dataset_path = st.sidebar.text_input("Dataset Path", "./datasets/dresscode")
output_dir = st.sidebar.text_input("Output Directory", "./outputs")
# Previously hard-coded to 7.5 in the backend args; exposed with the same default.
guidance_scale = st.sidebar.slider("Guidance Scale", 1.0, 10.0, 7.5)
guidance_scale_sketch = st.sidebar.slider("Sketch Guidance Scale", 1.0, 10.0, 7.5)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, max_value=16, value=1)
mixed_precision = st.sidebar.selectbox("Mixed Precision Mode", ["fp16", "fp32"], index=0)
seed = st.sidebar.number_input("Random Seed", value=42, step=1)

# Run Button
if st.button("Generate Image"):
    if uploaded_file:
        # Save uploaded sketch locally. The file name comes from the client, so
        # keep only its final path component; a crafted name such as
        # "../evil.png" must not be able to write outside temp_uploads.
        os.makedirs("temp_uploads", exist_ok=True)
        safe_name = os.path.basename(uploaded_file.name) or "sketch.png"
        sketch_path = os.path.join("temp_uploads", safe_name)
        with open(sketch_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Prepare arguments for the backend.
        # NOTE(review): sketch_path is written above but is not included in
        # ``args``, so the backend never receives the uploaded sketch. Confirm
        # which key src.eval.main expects for it and pass it through.
        args = {
            "pretrained_model_name_or_path": pretrained_model_path,
            "dataset": "dresscode",
            "dataset_path": dataset_path,
            "output_dir": output_dir,
            "guidance_scale": guidance_scale,
            "guidance_scale_sketch": guidance_scale_sketch,
            "mixed_precision": mixed_precision,
            "batch_size": batch_size,
            "seed": seed,
            "save_name": SAVE_NAME,  # Output file name (without extension)
        }

        # Run the backend model
        st.write("Generating image...")
        try:
            output_path = main(args)
            st.write("Image generation complete!")

            # Prefer the path main() returned when it names an existing file;
            # otherwise fall back to <output_dir>/<SAVE_NAME>.png. isfile (not
            # exists) so a directory at that path is not mistaken for an image.
            candidates = []
            if isinstance(output_path, (str, os.PathLike)):
                candidates.append(os.fspath(output_path))
            candidates.append(os.path.join(output_dir, f"{SAVE_NAME}.png"))
            output_image_path = next((p for p in candidates if os.path.isfile(p)), None)

            # Display the generated image
            if output_image_path is not None:
                # Context manager closes the file handle PIL keeps open.
                with Image.open(output_image_path) as output_image:
                    st.image(output_image, caption="Generated Image", use_column_width=True)
            else:
                st.error("Image generation failed. No output file found.")
        except Exception as e:
            # Top-level UI boundary: surface any backend failure to the user.
            st.error(f"An error occurred: {e}")
    else:
        st.error("Please upload a sketch before generating an image.")