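# NOTE: page residue from the hosting app's status banner
# ("Spaces:", "Sleeping", "Sleeping"); kept as a comment, not part of the program.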
import streamlit as st | |
import os | |
from PIL import Image | |
from src.eval import main # Import the modified main function from evl.py | |
# ---------------------------------------------------------------------------
# Page header and sketch upload
# ---------------------------------------------------------------------------
# Heading and one-line description shown at the top of the page.
st.title("Fashion Image Generator")
st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")

# Extensions the uploader will accept; `file_uploader` yields None until the
# user has provided a file.
accepted_sketch_types = ["png", "jpg", "jpeg"]
uploaded_file = st.file_uploader(
    "Upload your rough sketch (PNG, JPG, JPEG):",
    type=accepted_sketch_types,
)
# ---------------------------------------------------------------------------
# Sidebar: model and run configuration
# ---------------------------------------------------------------------------
sidebar = st.sidebar
sidebar.title("Model Configuration")

# Filesystem locations (free text; defaults are relative to the working dir).
pretrained_model_path = sidebar.text_input("Pretrained Model Path", value="./models")
dataset_path = sidebar.text_input("Dataset Path", value="./datasets/dresscode")
output_dir = sidebar.text_input("Output Directory", value="./outputs")

# Generation settings forwarded to the backend when the button is pressed.
guidance_scale_sketch = sidebar.slider(
    "Sketch Guidance Scale",
    min_value=1.0,
    max_value=10.0,
    value=7.5,
)
batch_size = sidebar.number_input("Batch Size", min_value=1, max_value=16, value=1)
mixed_precision = sidebar.selectbox(
    "Mixed Precision Mode",
    options=["fp16", "fp32"],
    index=0,
)
seed = sidebar.number_input("Random Seed", value=42, step=1)
# Run Button
# Single source of truth for the output file stem: it is passed to the backend
# as ``save_name`` and also used to locate ``<output_dir>/<stem>.png`` below.
# Assumes the backend writes a ``.png`` with that stem — TODO confirm against
# src.eval.main.
SAVE_NAME = "generated_image"

if st.button("Generate Image"):
    if uploaded_file:
        # Everything below is one user-facing action, so any failure (disk
        # write, backend, image decode) is reported in the UI rather than
        # crashing the Streamlit script.
        try:
            # Save uploaded sketch locally. The filename comes from the client
            # and is untrusted: drop any directory components so the write
            # cannot land outside temp_uploads.
            os.makedirs("temp_uploads", exist_ok=True)
            safe_name = os.path.basename(uploaded_file.name)
            if safe_name in ("", ".", ".."):
                safe_name = "sketch"
            sketch_path = os.path.join("temp_uploads", safe_name)
            with open(sketch_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            # NOTE(review): sketch_path is not included in ``args`` below, so
            # the uploaded sketch does not appear to reach the backend. Add it
            # under whatever key src.eval.main expects — TODO confirm.

            # Prepare arguments for the backend.
            args = {
                "pretrained_model_name_or_path": pretrained_model_path,
                "dataset": "dresscode",
                "dataset_path": dataset_path,
                "output_dir": output_dir,
                # Fixed; only the sketch guidance scale is exposed in the sidebar.
                "guidance_scale": 7.5,
                "guidance_scale_sketch": guidance_scale_sketch,
                "mixed_precision": mixed_precision,
                "batch_size": batch_size,
                "seed": seed,
                "save_name": SAVE_NAME,  # Output file name
            }

            # Run the backend model; the spinner disappears when it returns.
            with st.spinner("Generating image..."):
                output_path = main(args)
            st.write("Image generation complete!")

            # Prefer the conventional output location; if it is missing, fall
            # back to whatever file path the backend returned (if any).
            output_image_path = os.path.join(output_dir, f"{SAVE_NAME}.png")
            if (
                not os.path.isfile(output_image_path)
                and isinstance(output_path, (str, os.PathLike))
                and os.path.isfile(output_path)
            ):
                output_image_path = output_path

            if os.path.isfile(output_image_path):
                # Close the file handle once Streamlit has consumed the image.
                # NOTE(review): newer Streamlit releases deprecate
                # use_column_width in favour of use_container_width.
                with Image.open(output_image_path) as output_image:
                    st.image(output_image, caption="Generated Image", use_column_width=True)
            else:
                st.error("Image generation failed. No output file found.")
        except Exception as e:  # top-level UI boundary: surface, don't crash
            st.error(f"An error occurred: {e}")
    else:
        st.error("Please upload a sketch before generating an image.")