Generate_Images / app.py
ahmed792002's picture
Update app.py
ef30ad0 verified
raw
history blame
2.79 kB
import streamlit as st
import torch
import os
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
def install_and_patch_model(model_id):
    """Install a medigan model if needed and patch its ``base_GAN.py``.

    The patch replaces ``betas=(self.beta1, self.beta2)`` with the literal
    ``betas=(0.5, 0.999)`` in
    ``~/.medigan/models/<model_id>/model19/base_GAN.py``. The file is only
    rewritten when the replacement actually changes its contents, so calling
    this repeatedly is harmless.

    Args:
        model_id: medigan model identifier to install/patch.

    Returns:
        True when installation and patching raised no exception (including
        the case where the patch target file is absent — generation is still
        allowed to proceed), False otherwise. Errors are shown in the UI via
        ``st.error`` rather than raised.
    """
    try:
        generators = Generators()
        # NOTE(review): confirm `get_model_ids` / `download_and_install_model`
        # exist on `Generators` in the installed medigan version.
        if model_id not in generators.get_model_ids():
            generators.download_and_install_model(model_id)

        # NOTE(review): presumably medigan's on-disk install location for
        # this model — verify against the medigan config in use.
        model_path = os.path.join(
            os.path.expanduser("~"), ".medigan", "models", model_id
        )
        base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")

        if not os.path.exists(base_gan_path):
            # Previously this path still reported success; be honest instead,
            # but keep returning True so generation proceeds as before.
            st.warning(
                f"Patch target not found: {base_gan_path}; "
                "continuing without patching."
            )
            return True

        # Python source files are UTF-8 by default; don't rely on the locale.
        with open(base_gan_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Replace the optimizer betas with fixed values.
        patched = content.replace(
            "betas=(self.beta1, self.beta2)",
            "betas=(0.5, 0.999)",
        )

        if patched == content:
            # Either already patched on a previous run or the line is absent.
            st.info("No optimizer betas to patch; base_GAN.py left unchanged.")
            return True

        with open(base_gan_path, "w", encoding="utf-8") as f:
            f.write(patched)
        st.success("Model successfully patched!")
        return True
    except Exception as e:
        # UI boundary: surface the error instead of crashing the app.
        st.error(f"Patching failed: {str(e)}")
        return False
# medigan model identifiers offered in the sidebar "Select Model" picker.
# main() runs install_and_patch_model only for IDs containing "00019".
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Build the Streamlit page and generate images on demand.

    The sidebar holds the model picker and the sample-count slider (1-8,
    default 4). When the generate button is pressed, models whose ID contains
    "00019" go through ``install_and_patch_model`` first; a False result from
    it aborts the run. Otherwise ``generate_images`` renders the samples
    under a spinner.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)

    # Nothing to do until the user asks for images.
    if not st.button("✨ Generate Images"):
        return

    # Only the 00019 model gets its installed files patched; the patch step
    # is skipped entirely (short-circuit) for every other model.
    wants_patch = "00019" in model_id
    if wants_patch and not install_and_patch_model(model_id):
        return

    with st.spinner("Generating images..."):
        generate_images(num_images, model_id)
def generate_images(num_images, model_id):
    """Generate samples from a medigan model and show them in a grid.

    Makes ``num_images`` separate ``Generators.generate`` calls with
    ``num_samples=1``. Each result's first element becomes an RGB PIL image.
    All images are produced before anything is displayed. They are then laid
    out round-robin across 4 columns, each followed by a horizontal rule.
    Any exception is reported with ``st.error`` instead of propagating.

    Args:
        num_images: Number of single-sample generate calls to make.
        model_id: medigan model identifier (one of ``MODEL_IDS``).
    """
    try:
        gen = Generators()

        def _one_image():
            # NOTE(review): assumes generate() returns an indexable result
            # whose first item is accepted by to_pil_image (tensor/ndarray)
            # — confirm for the medigan version in use.
            batch = gen.generate(
                model_id=model_id,
                num_samples=1,
                install_dependencies=False,
            )
            return to_pil_image(batch[0]).convert("RGB")

        pictures = [_one_image() for _ in range(num_images)]

        grid = st.columns(4)
        for position, picture in enumerate(pictures):
            with grid[position % 4]:
                st.image(
                    picture,
                    caption=f"Image {position+1}",
                    use_column_width=True,
                )
                st.markdown("---")
    except Exception as e:
        st.error(f"Generation failed: {str(e)}")
# Script entry point: build the UI when this module is executed directly.
# NOTE(review): `streamlit run` is expected to execute the script as
# __main__ — confirm for the deployment target.
if __name__ == "__main__":
    main()