ve-gans / space.py
Zai
updated the model loader and readme
c6494da
raw
history blame
1.94 kB
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import numpy as np
# GAN model architecture
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# TO DO: implement generator architecture
def forward(self, z):
# TO DO: implement generator forward pass
pass
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# TO DO: implement discriminator architecture
def forward(self, x):
# TO DO: implement discriminator forward pass
pass
# Floorplan generator model architecture
class FloorplanGenerator(nn.Module):
def __init__(self):
super(FloorplanGenerator, self).__init__()
# TO DO: implement floorplan generator architecture
def forward(self, img):
# TO DO: implement floorplan generator forward pass
pass
# Load pre-trained models
@st.cache
def load_models():
generator = Generator()
discriminator = Discriminator()
floorplan_generator = FloorplanGenerator()
# TO DO: load pre-trained model weights
return generator, discriminator, floorplan_generator
# Streamlit app
st.title("GAN Image Generation and Floorplan App")
# Load models
generator, discriminator, floorplan_generator = load_models()
# Image generation
st.header("Image Generation")
z_dim = 100
noise = torch.randn(1, z_dim)
generated_img = generator(noise)
st.image(generated_img, caption="Generated Image")
# Floorplan generation
st.header("Floorplan Generation")
img_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if img_file:
img = Image.open(img_file)
img = np.array(img) / 255.0
floorplan = floorplan_generator(torch.tensor(img))
st.image(floorplan, caption="Generated Floorplan")
# Run the app
if __name__ == "__main__":
st.write("App is running!")