1aurent commited on
Commit
0510130
·
verified ·
1 Parent(s): d689c5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -5,6 +5,8 @@ import gradio as gr
5
  import numpy as np
6
 
7
  pipeline = DiffusionPipeline.from_pretrained("1aurent/ddpm-mnist")
 
 
8
 
9
  def predict(steps, seed):
10
  generator = torch.manual_seed(seed)
 
5
  import numpy as np
6
 
7
  pipeline = DiffusionPipeline.from_pretrained("1aurent/ddpm-mnist")
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ pipeline = pipeline.to(device=device)
10
 
11
  def predict(steps, seed):
12
  generator = torch.manual_seed(seed)