samsl commited on
Commit
f7f9537
·
1 Parent(s): 420fd6e

fix model device

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -84,11 +84,18 @@ def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
84
  try:
85
  run_id = uuid4()
86
  device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
87
 
88
  # gr.Info("Loading model...")
89
  _ = lm_embed("M", use_cuda = (device.type == "cuda"))
90
 
91
- model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=torch.cuda.is_available())
 
 
 
 
 
 
92
 
93
  # gr.Info("Loading files...")
94
  try:
 
84
  try:
85
  run_id = uuid4()
86
  device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
87
+ use_cuda = torch.cuda.is_available()
88
 
89
  # gr.Info("Loading model...")
90
  _ = lm_embed("M", use_cuda = (device.type == "cuda"))
91
 
92
+ model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=use_cuda)
93
+ if use_cuda:
94
+ model = model.to(device)
95
+ model.use_cuda = True
96
+ else:
97
+ model = model.to("cpu")
98
+ model.use_cuda = False
99
 
100
  # gr.Info("Loading files...")
101
  try: