Spaces:
Running
Running
import streamlit as st | |
import torch | |
import os | |
from medigan import Generators | |
from torchvision.transforms.functional import to_pil_image | |
def _patch_optimizer_betas(source):
    """Pin the optimizer ``betas`` in base_GAN.py source text to ``(0.5, 0.999)``.

    Args:
        source: Text of the model's ``base_GAN.py``.

    Returns:
        A ``(patched_source, changed)`` tuple. ``changed`` is False when the
        ``betas=(self.beta1, self.beta2)`` expression is not present, e.g.
        because an earlier run already patched the file.
    """
    patched = source.replace(
        "betas=(self.beta1, self.beta2)",
        "betas=(0.5, 0.999)",  # Force correct beta values
    )
    return patched, patched != source


def install_and_patch_model(model_id):
    """Ensure a medigan model is installed, then patch its base_GAN.py in place.

    The patch replaces ``betas=(self.beta1, self.beta2)`` with the constant
    ``betas=(0.5, 0.999)`` in ``<model dir>/model19/base_GAN.py``. The file is
    rewritten only when that text is still present, so repeated calls do not
    touch an already-patched file.

    Args:
        model_id: medigan model identifier, e.g. ``"00019_PGGAN_CHEST_XRAY"``.

    Returns:
        True if installation succeeded and the patch was applied, was already
        in place, or the target file was absent (a warning is shown in that
        case). False if any step raised; the error is reported via ``st.error``.
    """
    try:
        generators = Generators()
        # NOTE(review): confirm get_model_ids() lists *installed* models; if it
        # lists every known model ID, this download branch never runs.
        if model_id not in generators.get_model_ids():
            generators.download_and_install_model(model_id)

        # NOTE(review): assumes medigan unpacks models to ~/.medigan/models/<id>;
        # verify against the installed medigan version.
        model_path = os.path.join(os.path.expanduser("~"), ".medigan", "models", model_id)
        base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")

        if not os.path.exists(base_gan_path):
            # Don't claim success when nothing was patched. Still return True so
            # the caller goes on to attempt generation, as it did before.
            st.warning(f"Patch target not found at {base_gan_path}; skipping patch.")
            return True

        with open(base_gan_path, "r", encoding="utf-8") as f:
            content = f.read()
        patched, changed = _patch_optimizer_betas(content)
        if changed:
            with open(base_gan_path, "w", encoding="utf-8") as f:
                f.write(patched)
        st.success("Model successfully patched!")
        return True
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Patching failed: {e}")
        return False
# medigan model identifiers offered in the sidebar's model selector.
# The leading number is the medigan model number; the suffix names the
# architecture and dataset. The "00019" entry is special-cased by the app
# (it is installed and patched before generation).
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the Streamlit page.

    The model and image-count controls plus the Generate button live in the
    sidebar. Generation progress, status messages and the image grid are
    rendered in the main (wide) page area.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)
        generate_clicked = st.button("✨ Generate Images")

    # Handle the click outside the `with st.sidebar:` context. Otherwise the
    # spinner, messages and image columns would be drawn inside the sidebar.
    if not generate_clicked:
        return

    # The 00019 model goes through install_and_patch_model, which rewrites the
    # optimizer betas in its base_GAN.py. Stop if that step reports failure.
    if model_id.startswith("00019") and not install_and_patch_model(model_id):
        return

    with st.spinner("Generating images..."):
        generate_images(num_images, model_id)
def generate_images(num_images, model_id, num_columns=4):
    """Generate samples from a medigan model and display them in a grid.

    Args:
        num_images: Number of samples to request and display.
        model_id: medigan model identifier (see ``MODEL_IDS``).
        num_columns: Number of grid columns; images are distributed across
            them round-robin. Values below 1 are treated as 1.

    Any exception raised while generating or rendering is reported with
    ``st.error`` instead of propagating, so the page stays usable.
    """
    try:
        generators = Generators()
        # Request all samples in one call so the model is loaded once, not once
        # per image. save_images=False asks medigan to return the samples rather
        # than write them to disk (its default), which the display below needs.
        samples = generators.generate(
            model_id=model_id,
            num_samples=num_images,
            save_images=False,
            install_dependencies=False,
        )
        # len() rather than truthiness: samples may be a numpy array.
        if samples is None or len(samples) == 0:
            st.error("Generation failed: the model returned no samples.")
            return

        images = [to_pil_image(sample).convert("RGB") for sample in samples[:num_images]]

        num_columns = max(1, num_columns)
        cols = st.columns(num_columns)
        for idx, img in enumerate(images):
            with cols[idx % num_columns]:
                st.image(img, caption=f"Image {idx + 1}", use_column_width=True)
                st.markdown("---")
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"Generation failed: {e}")
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()