Spaces:
Runtime error
Runtime error
mueller-franzes
commited on
Commit
·
64a1bc8
1
Parent(s):
fec8e67
add device string
Browse files- streamlit/pages/chest.py +3 -3
- streamlit/pages/eye.py +3 -2
streamlit/pages/chest.py
CHANGED
@@ -16,15 +16,15 @@ guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, va
|
|
16 |
seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
|
17 |
cond_str = st.radio("Cardiomegaly", ('Yes', 'No'), index=1, help="Conditioned on 'cardiomegaly' or 'no cardiomegaly'", horizontal=True)
|
18 |
torch.manual_seed(seed)
|
19 |
-
|
20 |
-
device = torch.device(
|
21 |
|
22 |
@st.cache(allow_output_mutation = True)
|
23 |
def init_pipeline():
|
24 |
pipeline = DiffusionPipeline.load_from_checkpoint('runs/chest_diffusion/last.ckpt')
|
25 |
return pipeline
|
26 |
|
27 |
-
if st.button('Sample'):
|
28 |
cond = {'Yes':1, 'No':0}[cond_str]
|
29 |
condition = torch.tensor([cond]*n_samples, device=device)
|
30 |
un_cond = torch.tensor([1-cond]*n_samples, device=device)
|
|
|
16 |
seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
|
17 |
cond_str = st.radio("Cardiomegaly", ('Yes', 'No'), index=1, help="Conditioned on 'cardiomegaly' or 'no cardiomegaly'", horizontal=True)
|
18 |
torch.manual_seed(seed)
|
19 |
+
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
+
device = torch.device(device_str)
|
21 |
|
22 |
@st.cache(allow_output_mutation = True)
|
23 |
def init_pipeline():
|
24 |
pipeline = DiffusionPipeline.load_from_checkpoint('runs/chest_diffusion/last.ckpt')
|
25 |
return pipeline
|
26 |
|
27 |
+
if st.button(f'Sample (using {device_str})'):
|
28 |
cond = {'Yes':1, 'No':0}[cond_str]
|
29 |
condition = torch.tensor([cond]*n_samples, device=device)
|
30 |
un_cond = torch.tensor([1-cond]*n_samples, device=device)
|
streamlit/pages/eye.py
CHANGED
@@ -17,14 +17,15 @@ guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, va
|
|
17 |
seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
|
18 |
cond_str = st.radio("Glaucoma", ('Yes', 'No'), index=1, help="Conditioned on 'referable glaucoma' or 'no referable glaucoma'", horizontal=True)
|
19 |
torch.manual_seed(seed)
|
20 |
-
|
|
|
21 |
|
22 |
@st.cache(allow_output_mutation = True)
|
23 |
def init_pipeline():
|
24 |
pipeline = DiffusionPipeline.load_from_checkpoint('runs/eye_diffusion/last.ckpt')
|
25 |
return pipeline
|
26 |
|
27 |
-
if st.button('Sample'):
|
28 |
cond = {'Yes':1, 'No':0}[cond_str]
|
29 |
condition = torch.tensor([cond]*n_samples, device=device)
|
30 |
un_cond = torch.tensor([1-cond]*n_samples, device=device)
|
|
|
17 |
seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
|
18 |
cond_str = st.radio("Glaucoma", ('Yes', 'No'), index=1, help="Conditioned on 'referable glaucoma' or 'no referable glaucoma'", horizontal=True)
|
19 |
torch.manual_seed(seed)
|
20 |
+
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
device = torch.device(device_str)
|
22 |
|
23 |
@st.cache(allow_output_mutation = True)
|
24 |
def init_pipeline():
|
25 |
pipeline = DiffusionPipeline.load_from_checkpoint('runs/eye_diffusion/last.ckpt')
|
26 |
return pipeline
|
27 |
|
28 |
+
if st.button(f'Sample (using {device_str})'):
|
29 |
cond = {'Yes':1, 'No':0}[cond_str]
|
30 |
condition = torch.tensor([cond]*n_samples, device=device)
|
31 |
un_cond = torch.tensor([1-cond]*n_samples, device=device)
|