mueller-franzes commited on
Commit
64a1bc8
·
1 Parent(s): fec8e67

add device string

Browse files
streamlit/pages/chest.py CHANGED
@@ -16,15 +16,15 @@ guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, va
16
  seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
17
  cond_str = st.radio("Cardiomegaly", ('Yes', 'No'), index=1, help="Conditioned on 'cardiomegaly' or 'no cardiomegaly'", horizontal=True)
18
  torch.manual_seed(seed)
19
-
20
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
 
22
  @st.cache(allow_output_mutation = True)
23
  def init_pipeline():
24
  pipeline = DiffusionPipeline.load_from_checkpoint('runs/chest_diffusion/last.ckpt')
25
  return pipeline
26
 
27
- if st.button('Sample'):
28
  cond = {'Yes':1, 'No':0}[cond_str]
29
  condition = torch.tensor([cond]*n_samples, device=device)
30
  un_cond = torch.tensor([1-cond]*n_samples, device=device)
 
16
  seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
17
  cond_str = st.radio("Cardiomegaly", ('Yes', 'No'), index=1, help="Conditioned on 'cardiomegaly' or 'no cardiomegaly'", horizontal=True)
18
  torch.manual_seed(seed)
19
+ device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ device = torch.device(device_str)
21
 
22
  @st.cache(allow_output_mutation = True)
23
  def init_pipeline():
24
  pipeline = DiffusionPipeline.load_from_checkpoint('runs/chest_diffusion/last.ckpt')
25
  return pipeline
26
 
27
+ if st.button(f'Sample (using {device_str})'):
28
  cond = {'Yes':1, 'No':0}[cond_str]
29
  condition = torch.tensor([cond]*n_samples, device=device)
30
  un_cond = torch.tensor([1-cond]*n_samples, device=device)
streamlit/pages/eye.py CHANGED
@@ -17,14 +17,15 @@ guidance_scale = st.number_input("Guidance scale", min_value=1, max_value=10, va
17
  seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
18
  cond_str = st.radio("Glaucoma", ('Yes', 'No'), index=1, help="Conditioned on 'referable glaucoma' or 'no referable glaucoma'", horizontal=True)
19
  torch.manual_seed(seed)
20
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
21
 
22
  @st.cache(allow_output_mutation = True)
23
  def init_pipeline():
24
  pipeline = DiffusionPipeline.load_from_checkpoint('runs/eye_diffusion/last.ckpt')
25
  return pipeline
26
 
27
- if st.button('Sample'):
28
  cond = {'Yes':1, 'No':0}[cond_str]
29
  condition = torch.tensor([cond]*n_samples, device=device)
30
  un_cond = torch.tensor([1-cond]*n_samples, device=device)
 
17
  seed = st.number_input("Seed", min_value=0, max_value=None, value=1)
18
  cond_str = st.radio("Glaucoma", ('Yes', 'No'), index=1, help="Conditioned on 'referable glaucoma' or 'no referable glaucoma'", horizontal=True)
19
  torch.manual_seed(seed)
20
+ device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+ device = torch.device(device_str)
22
 
23
  @st.cache(allow_output_mutation = True)
24
  def init_pipeline():
25
  pipeline = DiffusionPipeline.load_from_checkpoint('runs/eye_diffusion/last.ckpt')
26
  return pipeline
27
 
28
+ if st.button(f'Sample (using {device_str})'):
29
  cond = {'Yes':1, 'No':0}[cond_str]
30
  condition = torch.tensor([cond]*n_samples, device=device)
31
  un_cond = torch.tensor([1-cond]*n_samples, device=device)