Spaces:
Runtime error
Runtime error
Commit
·
77a996f
1
Parent(s):
db804d4
update model weight path
Browse files
app.py
CHANGED
@@ -4,7 +4,6 @@ from torchtext.data.utils import get_tokenizer
|
|
4 |
import numpy as np
|
5 |
import subprocess
|
6 |
|
7 |
-
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
from transformer import Transformer
|
10 |
|
@@ -30,7 +29,7 @@ src_vocab = len(vocab)
|
|
30 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
model = Transformer(len(vocab), len(vocab), d_model, N, heads).to(device)
|
32 |
model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
|
33 |
-
filename="
|
34 |
model.eval()
|
35 |
|
36 |
def respond(custom_string):
|
|
|
4 |
import numpy as np
|
5 |
import subprocess
|
6 |
|
|
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
from transformer import Transformer
|
9 |
|
|
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
model = Transformer(len(vocab), len(vocab), d_model, N, heads).to(device)
|
31 |
model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
|
32 |
+
filename="alpaca_train_400_epoch.pt"), map_location=device))
|
33 |
model.eval()
|
34 |
|
35 |
def respond(custom_string):
|