Spaces:
Running
Running
File size: 2,787 Bytes
c6d85c0 078cd45 ef30ad0 c6d85c0 ef30ad0 078cd45 ef30ad0 078cd45 ef30ad0 078cd45 ef30ad0 078cd45 ef30ad0 078cd45 ef30ad0 078cd45 ef30ad0 078cd45 c6d85c0 afbb2d7 c6d85c0 afbb2d7 078cd45 afbb2d7 078cd45 ef30ad0 078cd45 c6d85c0 ef30ad0 078cd45 ef30ad0 c6d85c0 afbb2d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import streamlit as st
import torch
import os
from medigan import Generators
from torchvision.transforms.functional import to_pil_image
def install_and_patch_model(model_id):
    """Ensure ``model_id`` is installed, then patch the optimizer betas in its ``base_GAN.py``.

    The patch replaces ``betas=(self.beta1, self.beta2)`` with the literal
    ``betas=(0.5, 0.999)`` in
    ``~/.medigan/models/<model_id>/model19/base_GAN.py``. Progress and failures
    are reported to the user through Streamlit messages.

    Args:
        model_id: medigan model identifier (e.g. ``"00019_PGGAN_CHEST_XRAY"``).

    Returns:
        True when installation/patching finished, including the cases where the
        file is missing or the target text is no longer present (nothing to do).
        False when any exception was raised; the error is shown via ``st.error``.
    """
    old_betas = "betas=(self.beta1, self.beta2)"
    new_betas = "betas=(0.5, 0.999)"  # Force correct beta values
    try:
        # Install model if not exists.
        # NOTE(review): confirm get_model_ids()/download_and_install_model() exist in the
        # installed medigan version; an AttributeError here surfaces as "Patching failed".
        generators = Generators()
        if model_id not in generators.get_model_ids():
            generators.download_and_install_model(model_id)

        # Locate model files.
        # NOTE(review): assumes medigan unpacks models under ~/.medigan/models — confirm.
        model_path = os.path.join(os.path.expanduser("~"), ".medigan", "models", model_id)
        base_gan_path = os.path.join(model_path, "model19", "base_GAN.py")

        if not os.path.exists(base_gan_path):
            # Keep the previous "proceed anyway" behaviour, but don't skip silently.
            st.warning(f"base_GAN.py not found at {base_gan_path}; skipping patch.")
            return True

        # Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail on,
        # or corrupt, non-ASCII characters in the model source.
        with open(base_gan_path, "r", encoding="utf-8") as f:
            content = f.read()

        if old_betas in content:
            with open(base_gan_path, "w", encoding="utf-8") as f:
                f.write(content.replace(old_betas, new_betas))
            st.success("Model successfully patched!")
        else:
            # Already patched on an earlier click (or upstream code changed): don't rewrite
            # the file or claim that a patch was applied.
            st.info("Patch target not found in base_GAN.py (already patched?); no changes made.")
        return True
    except Exception as e:
        # UI boundary: report the failure and let the caller abort generation.
        st.error(f"Patching failed: {e}")
        return False
# medigan model identifiers offered in the sidebar "Select Model" dropbox.
# Each ID starts with a five-digit model number; main() runs
# install_and_patch_model() first for IDs containing "00019".
MODEL_IDS = [
    "00001_DCGAN_MMG_CALC_ROI",
    "00002_DCGAN_MMG_MASS_ROI",
    "00003_CYCLEGAN_MMG_DENSITY_FULL",
    "00004_PIX2PIX_MMG_MASSES_W_MASKS",
    "00019_PGGAN_CHEST_XRAY",
]
def main():
    """Render the page: settings in the sidebar, then generate on button click.

    The sidebar holds the model selector (``MODEL_IDS``) and an image-count
    slider (1-8, default 4). When "✨ Generate Images" is clicked, models whose
    ID contains ``"00019"`` are first installed/patched; if that step reports
    failure, nothing is generated.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    with st.sidebar:
        st.header("⚙️ Settings")
        chosen_model = st.selectbox("Select Model", MODEL_IDS)
        image_count = st.slider("Number of Images", 1, 8, 4)

    # Guard clause: nothing to do until the user asks for images.
    if not st.button("✨ Generate Images"):
        return

    requires_patch = "00019" in chosen_model
    # Patching is only attempted (and can only abort) for the 00019 model.
    if requires_patch and not install_and_patch_model(chosen_model):
        return

    with st.spinner("Generating images..."):
        generate_images(image_count, chosen_model)
def generate_images(num_images, model_id):
    """Generate ``num_images`` samples from ``model_id`` and show them in a 4-column grid.

    Each image comes from its own ``Generators.generate`` call with
    ``num_samples=1``. Images are captioned "Image 1", "Image 2", ... and placed
    round-robin across the four columns, followed by a horizontal rule. Any
    exception during generation or rendering is caught and shown via ``st.error``.

    Args:
        num_images: Number of images to generate.
        model_id: medigan model identifier passed to ``generate``.
    """
    try:
        gen = Generators()

        def _single_image():
            # NOTE(review): assumes generate() returns an indexable batch whose first
            # element to_pil_image accepts (tensor or ndarray) — confirm per model.
            batch = gen.generate(
                model_id=model_id,
                num_samples=1,
                install_dependencies=False,
            )
            return to_pil_image(batch[0]).convert("RGB")

        pictures = [_single_image() for _ in range(num_images)]

        grid = st.columns(4)
        for position, picture in enumerate(pictures, start=1):
            # position is 1-based for the caption; column slot is 0-based round-robin.
            with grid[(position - 1) % 4]:
                st.image(picture, caption=f"Image {position}", use_column_width=True)
        st.markdown("---")
    except Exception as err:
        st.error(f"Generation failed: {str(err)}")
if __name__ == "__main__":
    # Script entry point.
    main()