"""Streamlit front end for generating synthetic medical images with medigan."""

import os

import streamlit as st
import torch  # NOTE(review): not referenced in this file; kept in case medigan models rely on it being importable
from medigan import Generators
from torchvision.transforms.functional import to_pil_image

# Model IDs offered in the sidebar selector.
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]

# Number of columns in the generated-image grid.
GRID_COLUMNS = 4

# Optimizer line in model 00019's base_GAN.py that gets rewritten, and its replacement.
_BETAS_ORIGINAL = "betas=(self.beta1, self.beta2)"
_BETAS_PATCHED = "betas=(0.5, 0.999)"  # Force correct beta values


def install_and_patch_model(model_id):
    """Install ``model_id`` if needed, then patch the optimizer betas in its base_GAN.py.

    The patch replaces ``_BETAS_ORIGINAL`` with ``_BETAS_PATCHED`` in
    ``<model dir>/model19/base_GAN.py``. The file is only rewritten when the
    original text is actually present, so repeated calls are harmless.

    Args:
        model_id: medigan model identifier, e.g. "00019_PGGAN_CHEST_XRAY".

    Returns:
        True if installation succeeded and patching succeeded or was not needed;
        False if any step raised. Errors are reported in the Streamlit UI.
    """
    try:
        generators = Generators()
        if model_id not in generators.get_model_ids():
            generators.download_and_install_model(model_id)

        # NOTE(review): assumes medigan installs models under ~/.medigan/models/<model_id> — confirm
        model_path = os.path.join(os.path.expanduser("~"), ".medigan", "models", model_id)
        base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")

        if not os.path.exists(base_gan_path):
            # Nothing to patch; do not claim success, but let generation proceed as before.
            st.warning(f"Patch target not found, skipping patch: {base_gan_path}")
            return True

        with open(base_gan_path, "r", encoding="utf-8") as f:
            content = f.read()

        if _BETAS_ORIGINAL not in content:
            # Already patched (or upstream changed); avoid a needless rewrite.
            st.info("No optimizer betas to patch; base_GAN.py left unchanged.")
            return True

        with open(base_gan_path, "w", encoding="utf-8") as f:
            f.write(content.replace(_BETAS_ORIGINAL, _BETAS_PATCHED))

        st.success("Model successfully patched!")
        return True
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Patching failed: {e}")
        return False


def main():
    """Render the page: sidebar settings plus a button that triggers generation."""
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)

    if st.button("✨ Generate Images"):
        # Only model 00019 needs the base_GAN.py patch; abort if patching failed.
        if "00019" in model_id and not install_and_patch_model(model_id):
            return
        with st.spinner("Generating images..."):
            generate_images(num_images, model_id)


def generate_images(num_images, model_id):
    """Generate ``num_images`` samples with ``model_id`` and show them in a grid.

    Args:
        num_images: Number of samples to request from the model.
        model_id: medigan model identifier.

    Errors are reported in the Streamlit UI rather than raised.
    """
    try:
        generators = Generators()
        # One call so the model is loaded once. save_images=False asks medigan to
        # return the samples instead of writing them to disk.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False,
        )
        if samples is None or len(samples) == 0:
            st.warning("The model returned no images.")
            return

        # NOTE(review): assumes each sample is an image array/tensor accepted by
        # to_pil_image; mask-producing models may return pairs — confirm per model.
        images = [to_pil_image(sample).convert("RGB") for sample in list(samples)[:num_images]]

        cols = st.columns(GRID_COLUMNS)
        for idx, img in enumerate(images):
            with cols[idx % GRID_COLUMNS]:
                # NOTE(review): newer Streamlit releases deprecate use_column_width — check installed version.
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
        st.markdown("---")
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Generation failed: {e}")


if __name__ == "__main__":
    main()