Spaces:
Running
on
Zero
Running
on
Zero
Update for ZeroGPU
Browse files- model/barlow_twins.py +2 -1
- model/base_model.py +0 -2
model/barlow_twins.py
CHANGED
@@ -14,7 +14,6 @@ import spaces
|
|
14 |
from model.base_model import BaseModel
|
15 |
|
16 |
|
17 |
-
@spaces.GPU
|
18 |
class BarlowTwins(BaseModel):
|
19 |
def __init__(
|
20 |
self,
|
@@ -361,6 +360,7 @@ class BarlowTwins(BaseModel):
|
|
361 |
if self.param_dict["verbose"] is True:
|
362 |
print("[BT]: Training finished")
|
363 |
|
|
|
364 |
def encode(
|
365 |
self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
|
366 |
) -> np.ndarray:
|
@@ -406,6 +406,7 @@ class BarlowTwins(BaseModel):
|
|
406 |
# convert back to numpy
|
407 |
return embedding.cpu().detach().numpy()
|
408 |
|
|
|
409 |
def zero_shot(
|
410 |
self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
|
411 |
) -> np.ndarray:
|
|
|
14 |
from model.base_model import BaseModel
|
15 |
|
16 |
|
|
|
17 |
class BarlowTwins(BaseModel):
|
18 |
def __init__(
|
19 |
self,
|
|
|
360 |
if self.param_dict["verbose"] is True:
|
361 |
print("[BT]: Training finished")
|
362 |
|
363 |
+
@spaces.GPU
|
364 |
def encode(
|
365 |
self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
|
366 |
) -> np.ndarray:
|
|
|
406 |
# convert back to numpy
|
407 |
return embedding.cpu().detach().numpy()
|
408 |
|
409 |
+
@spaces.GPU
|
410 |
def zero_shot(
|
411 |
self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
|
412 |
) -> np.ndarray:
|
model/base_model.py
CHANGED
@@ -2,10 +2,8 @@ from typing import Tuple, Any, Union
|
|
2 |
import torch
|
3 |
from torch import nn
|
4 |
import numpy as np
|
5 |
-
import spaces
|
6 |
|
7 |
|
8 |
-
@spaces.GPU
|
9 |
class BaseModel(nn.Module):
|
10 |
def __init__(self):
|
11 |
super(BaseModel, self).__init__()
|
|
|
2 |
import torch
|
3 |
from torch import nn
|
4 |
import numpy as np
|
|
|
5 |
|
6 |
|
|
|
7 |
class BaseModel(nn.Module):
|
8 |
def __init__(self):
|
9 |
super(BaseModel, self).__init__()
|