mschuh commited on
Commit
c07eff0
β€’
1 Parent(s): 02327b3

Update for ZeroGPU

Browse files
Files changed (2) hide show
  1. model/barlow_twins.py +2 -1
  2. model/base_model.py +0 -2
model/barlow_twins.py CHANGED
@@ -14,7 +14,6 @@ import spaces
14
  from model.base_model import BaseModel
15
 
16
 
17
- @spaces.GPU
18
  class BarlowTwins(BaseModel):
19
  def __init__(
20
  self,
@@ -361,6 +360,7 @@ class BarlowTwins(BaseModel):
361
  if self.param_dict["verbose"] is True:
362
  print("[BT]: Training finished")
363
 
 
364
  def encode(
365
  self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
366
  ) -> np.ndarray:
@@ -406,6 +406,7 @@ class BarlowTwins(BaseModel):
406
  # convert back to numpy
407
  return embedding.cpu().detach().numpy()
408
 
 
409
  def zero_shot(
410
  self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
411
  ) -> np.ndarray:
 
14
  from model.base_model import BaseModel
15
 
16
 
 
17
  class BarlowTwins(BaseModel):
18
  def __init__(
19
  self,
 
360
  if self.param_dict["verbose"] is True:
361
  print("[BT]: Training finished")
362
 
363
+ @spaces.GPU
364
  def encode(
365
  self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
366
  ) -> np.ndarray:
 
406
  # convert back to numpy
407
  return embedding.cpu().detach().numpy()
408
 
409
+ @spaces.GPU
410
  def zero_shot(
411
  self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
412
  ) -> np.ndarray:
model/base_model.py CHANGED
@@ -2,10 +2,8 @@ from typing import Tuple, Any, Union
2
  import torch
3
  from torch import nn
4
  import numpy as np
5
- import spaces
6
 
7
 
8
- @spaces.GPU
9
  class BaseModel(nn.Module):
10
  def __init__(self):
11
  super(BaseModel, self).__init__()
 
2
  import torch
3
  from torch import nn
4
  import numpy as np
 
5
 
6
 
 
7
  class BaseModel(nn.Module):
8
  def __init__(self):
9
  super(BaseModel, self).__init__()