Spaces:
Running
on
Zero
Running
on
Zero
Update to ZeroGPU
Browse files- app.py +1 -1
- model/barlow_twins.py +2 -0
- model/base_model.py +2 -0
- model/model.py +2 -1
- utils/sequence.py +2 -0
app.py
CHANGED
@@ -128,4 +128,4 @@ iface = gr.Interface(
|
|
128 |
theme=theme
|
129 |
)
|
130 |
|
131 |
-
iface.launch()
|
|
|
128 |
theme=theme
|
129 |
)
|
130 |
|
131 |
+
iface.launch(share=True)
|
model/barlow_twins.py
CHANGED
@@ -9,10 +9,12 @@ import os
|
|
9 |
import pickle
|
10 |
import inspect
|
11 |
from tqdm.auto import trange
|
|
|
12 |
|
13 |
from model.base_model import BaseModel
|
14 |
|
15 |
|
|
|
16 |
class BarlowTwins(BaseModel):
|
17 |
def __init__(
|
18 |
self,
|
|
|
9 |
import pickle
|
10 |
import inspect
|
11 |
from tqdm.auto import trange
|
12 |
+
import spaces
|
13 |
|
14 |
from model.base_model import BaseModel
|
15 |
|
16 |
|
17 |
+
@spaces.GPU
|
18 |
class BarlowTwins(BaseModel):
|
19 |
def __init__(
|
20 |
self,
|
model/base_model.py
CHANGED
@@ -2,8 +2,10 @@ from typing import Tuple, Any, Union
|
|
2 |
import torch
|
3 |
from torch import nn
|
4 |
import numpy as np
|
|
|
5 |
|
6 |
|
|
|
7 |
class BaseModel(nn.Module):
|
8 |
def __init__(self):
|
9 |
super(BaseModel, self).__init__()
|
|
|
2 |
import torch
|
3 |
from torch import nn
|
4 |
import numpy as np
|
5 |
+
import spaces
|
6 |
|
7 |
|
8 |
+
@spaces.GPU
|
9 |
class BaseModel(nn.Module):
|
10 |
def __init__(self):
|
11 |
super(BaseModel, self).__init__()
|
model/model.py
CHANGED
@@ -17,6 +17,7 @@ import torch
|
|
17 |
from typing import *
|
18 |
from rdkit import RDLogger
|
19 |
RDLogger.DisableLog("rdApp.*")
|
|
|
20 |
|
21 |
from xgboost import XGBClassifier, DMatrix
|
22 |
|
@@ -26,7 +27,7 @@ from model.barlow_twins import BarlowTwins
|
|
26 |
from utils.sequence import uniprot2sequence, encode_sequences
|
27 |
|
28 |
|
29 |
-
|
30 |
class DTIModel:
|
31 |
def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
|
32 |
self.bt_model = BarlowTwins()
|
|
|
17 |
from typing import *
|
18 |
from rdkit import RDLogger
|
19 |
RDLogger.DisableLog("rdApp.*")
|
20 |
+
import spaces
|
21 |
|
22 |
from xgboost import XGBClassifier, DMatrix
|
23 |
|
|
|
27 |
from utils.sequence import uniprot2sequence, encode_sequences
|
28 |
|
29 |
|
30 |
+
@spaces.GPU
|
31 |
class DTIModel:
|
32 |
def __init__(self, bt_model_path: str, gbm_model_path: str, encoder: str = "prost_t5"):
|
33 |
self.bt_model = BarlowTwins()
|
utils/sequence.py
CHANGED
@@ -8,6 +8,7 @@ import concurrent.futures
|
|
8 |
from tqdm.auto import tqdm
|
9 |
import multiprocessing
|
10 |
from multiprocessing import Pool
|
|
|
11 |
|
12 |
|
13 |
ENCODERS = {
|
@@ -49,6 +50,7 @@ def uniprot2sequence(uniprot_id):
|
|
49 |
return None
|
50 |
|
51 |
|
|
|
52 |
def encode_sequences(sequences: list, encoder: str):
|
53 |
if encoder not in ENCODERS.keys():
|
54 |
raise ValueError(f"Invalid encoder: {encoder}")
|
|
|
8 |
from tqdm.auto import tqdm
|
9 |
import multiprocessing
|
10 |
from multiprocessing import Pool
|
11 |
+
import spaces
|
12 |
|
13 |
|
14 |
ENCODERS = {
|
|
|
50 |
return None
|
51 |
|
52 |
|
53 |
+
@spaces.GPU
|
54 |
def encode_sequences(sequences: list, encoder: str):
|
55 |
if encoder not in ENCODERS.keys():
|
56 |
raise ValueError(f"Invalid encoder: {encoder}")
|