Spaces:
Runtime error
Runtime error
Commit
·
725dc81
1
Parent(s):
b43e55e
Update generate.py
Browse files- generate.py +5 -5
generate.py
CHANGED
@@ -21,17 +21,17 @@ TYPE_WRITER=1 # whether output streamly
|
|
21 |
|
22 |
args = parser.parse_args()
|
23 |
print(args)
|
24 |
-
tokenizer = LlamaTokenizer.from_pretrained(
|
25 |
|
26 |
LOAD_8BIT = True
|
27 |
|
28 |
|
29 |
|
30 |
# fix the path for local checkpoint
|
31 |
-
lora_bin_path = os.path.join(
|
32 |
print(lora_bin_path)
|
33 |
-
if not os.path.exists(lora_bin_path) and
|
34 |
-
pytorch_bin_path = os.path.join(
|
35 |
print(pytorch_bin_path)
|
36 |
if os.path.exists(pytorch_bin_path):
|
37 |
os.rename(pytorch_bin_path, lora_bin_path)
|
@@ -140,7 +140,7 @@ def evaluate(
|
|
140 |
**kwargs,
|
141 |
)
|
142 |
with torch.no_grad():
|
143 |
-
if
|
144 |
for generation_output in model.stream_generate(
|
145 |
input_ids=input_ids,
|
146 |
generation_config=generation_config,
|
|
|
21 |
|
22 |
args = parser.parse_args()
|
23 |
print(args)
|
24 |
+
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODE)
|
25 |
|
26 |
LOAD_8BIT = True
|
27 |
|
28 |
|
29 |
|
30 |
# fix the path for local checkpoint
|
31 |
+
lora_bin_path = os.path.join(LORA_PATH, "adapter_model.bin")
|
32 |
print(lora_bin_path)
|
33 |
+
if not os.path.exists(lora_bin_path) and USE_LOCAL:
|
34 |
+
pytorch_bin_path = os.path.join(LORA_PATH, "pytorch_model.bin")
|
35 |
print(pytorch_bin_path)
|
36 |
if os.path.exists(pytorch_bin_path):
|
37 |
os.rename(pytorch_bin_path, lora_bin_path)
|
|
|
140 |
**kwargs,
|
141 |
)
|
142 |
with torch.no_grad():
|
143 |
+
if TYPE_WRITER:
|
144 |
for generation_output in model.stream_generate(
|
145 |
input_ids=input_ids,
|
146 |
generation_config=generation_config,
|