teachyourselfcoding commited on
Commit
725dc81
·
1 Parent(s): b43e55e

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +5 -5
generate.py CHANGED
@@ -21,17 +21,17 @@ TYPE_WRITER=1 # whether output streamly
21
 
22
  args = parser.parse_args()
23
  print(args)
24
- tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
25
 
26
  LOAD_8BIT = True
27
 
28
 
29
 
30
  # fix the path for local checkpoint
31
- lora_bin_path = os.path.join(args.lora_path, "adapter_model.bin")
32
  print(lora_bin_path)
33
- if not os.path.exists(lora_bin_path) and args.use_local:
34
- pytorch_bin_path = os.path.join(args.lora_path, "pytorch_model.bin")
35
  print(pytorch_bin_path)
36
  if os.path.exists(pytorch_bin_path):
37
  os.rename(pytorch_bin_path, lora_bin_path)
@@ -140,7 +140,7 @@ def evaluate(
140
  **kwargs,
141
  )
142
  with torch.no_grad():
143
- if args.use_typewriter:
144
  for generation_output in model.stream_generate(
145
  input_ids=input_ids,
146
  generation_config=generation_config,
 
21
 
22
  args = parser.parse_args()
23
  print(args)
24
+ tokenizer = LlamaTokenizer.from_pretrained(BASE_MODE)
25
 
26
  LOAD_8BIT = True
27
 
28
 
29
 
30
  # fix the path for local checkpoint
31
+ lora_bin_path = os.path.join(LORA_PATH, "adapter_model.bin")
32
  print(lora_bin_path)
33
+ if not os.path.exists(lora_bin_path) and USE_LOCAL:
34
+ pytorch_bin_path = os.path.join(LORA_PATH, "pytorch_model.bin")
35
  print(pytorch_bin_path)
36
  if os.path.exists(pytorch_bin_path):
37
  os.rename(pytorch_bin_path, lora_bin_path)
 
140
  **kwargs,
141
  )
142
  with torch.no_grad():
143
+ if TYPE_WRITER:
144
  for generation_output in model.stream_generate(
145
  input_ids=input_ids,
146
  generation_config=generation_config,