fix model device
Browse files
app.py
CHANGED
@@ -84,11 +84,18 @@ def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
|
|
84 |
try:
|
85 |
run_id = uuid4()
|
86 |
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
87 |
|
88 |
# gr.Info("Loading model...")
|
89 |
_ = lm_embed("M", use_cuda = (device.type == "cuda"))
|
90 |
|
91 |
-
model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=(device.type == "cuda"))
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
# gr.Info("Loading files...")
|
94 |
try:
|
|
|
84 |
try:
|
85 |
run_id = uuid4()
|
86 |
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
87 |
+
use_cuda = torch.cuda.is_available()
|
88 |
|
89 |
# gr.Info("Loading model...")
|
90 |
_ = lm_embed("M", use_cuda = (device.type == "cuda"))
|
91 |
|
92 |
+
model = DSCRIPTModel.from_pretrained(model_map[model_name], use_cuda=use_cuda)
|
93 |
+
if use_cuda:
|
94 |
+
model = model.to(device)
|
95 |
+
model.use_cuda = True
|
96 |
+
else:
|
97 |
+
model = model.to("cpu")
|
98 |
+
model.use_cuda = False
|
99 |
|
100 |
# gr.Info("Loading files...")
|
101 |
try:
|