MasaakiKotera committed
Commit 8f79412
1 parent: 106d320

Upload sampling.py with huggingface_hub
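Judging from the diff below, the commit adds a --beam_size option, a --fasta output mode, and an assertion on --strategy. A hypothetical invocation of the updated script (paths and values are placeholders, not taken from the commit; --init_from is assumed to be the flag behind args.init_from, which is defined above the first hunk):

    python sampling.py --init_from resume --out_path samples.txt \
        --max_new_tokens 256 --strategy beam_search --beam_size 5 \
        --ckpt_path ckpt.pt --tokenizer_path tokenizer.model

Note that --fasta is declared with action='store_true' but also default=True, so FASTA output is always enabled; this diff adds no switch to turn it off.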

Files changed (1):
  sampling.py (+54 −16)
sampling.py CHANGED
@@ -23,6 +23,7 @@ parser.add_argument("--out_path", type=str, required=True)
 parser.add_argument("--num_samples", type=int, required=False, default=100000)
 parser.add_argument("--max_new_tokens", type=int, required=True, help="number of tokens generated in each sample")
 parser.add_argument("--strategy",type=str, required=False,default='top_k',help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']")
+parser.add_argument("--beam_size",type=int, required=False,default=3,help="beam size for beam search")
 parser.add_argument("--temperature",type=float, required=False,default=1.0,help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions")
 parser.add_argument("--top_k",type=int, required=False,default=20,help="retain only the top_k most likely tokens, clamp others to have 0 probability")
 parser.add_argument("--ckpt_path",type=str, required=True,help="path to a checkpoint/model")
@@ -30,6 +31,7 @@ parser.add_argument("--tokenizer_path",type=str, required=True,help="path to a t
 parser.add_argument("--start",type=str, required=False,default="<|endoftext|>")
 parser.add_argument("--repetition_penalty",type=float, required=False,default=1.0)
 parser.add_argument("--shuffle_token", action='store_true', help="Enable shuffling of tokens before decoding")
+parser.add_argument("--fasta", action='store_true', default=True, help="Enable writing output in FASTA format")
 
 args = parser.parse_args()
 init_from = args.init_from
@@ -37,17 +39,20 @@ out_path = args.out_path
 num_samples = args.num_samples
 max_new_tokens = args.max_new_tokens
 strategy = args.strategy
+assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
+beam_size = args.beam_size
 temperature = args.temperature
 top_k = args.top_k
 ckpt_path = args.ckpt_path
 tokenizer_path = args.tokenizer_path
 start = args.start
 repetition_penalty = args.repetition_penalty
+fasta = args.fasta
+
 
 # -----------------------------------------------------------------------------
 seed = random.randint(1,6666)
-# seed = 1337
-device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+device = 'cuda'
 dtype = 'float32'
 # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = False # use PyTorch 2.0 to compile the model to be faster
@@ -91,20 +96,53 @@ load_meta = False
 encode = tokenizer.encode
 decode = tokenizer.decode
 
-start_ids = encode("".join(start))
-x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+fasta_out_path = os.path.splitext(out_path)[0] + ".fasta" if fasta else None
+
+if strategy in["sampling", "top_k"]:
+    start_ids = encode("".join(start))
+    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+
+    with open(out_path, 'a') as f:
+        with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
+            with torch.no_grad():
+                with ctx:
+                    for k in tqdm(range(num_samples), desc="Generating samples"):
+                        token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
+
+                        # Shuffle tokens if --shuffle_token is specified
+                        if args.shuffle_token:
+                            random.shuffle(token_sequence)
+
+                        y = decode(token_sequence).replace(' ', '')
+                        # y = decode(token_sequence).replace('\n', '').replace(' ', '') + '\n'
+                        f.write(y)
+                        f.flush()
+
+
+                        if fasta:
+                            fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
+                            fasta_f.write(fasta_entry.strip() + '\n')
+                            fasta_f.flush()
+
+
+elif strategy in ["beam_search", "greedy_search"]:
+    with open(out_path, 'a') as f:
+        with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
+            with torch.no_grad():
+                with ctx:
+                    start = '<|endoftext|>'
+                    start_ids = encode(start)
+                    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+                    token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, beam_size=beam_size)[0].tolist()
 
+                    y = decode(token_sequence).replace(' ', '')
+                    f.write(y)
+                    f.flush()
 
-with open(out_path, 'a') as f:
-    with torch.no_grad():
-        with ctx:
-            for k in tqdm(range(num_samples), desc="Generating samples"):
-                token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
-
-                # Shuffle tokens if --shuffle_token is specified
-                if args.shuffle_token:
-                    random.shuffle(token_sequence)
 
-                y = decode(token_sequence) + '\n'
-                f.write(y)
-                f.flush()
+                    if fasta:
+                        fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
+                        fasta_f.write(fasta_entry.strip() + '\n')
+                        fasta_f.flush()
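The new write path opens a plain-text file and, optionally, a FASTA file through open(fasta_out_path, 'a') if fasta else nullcontext(), which implies that sampling.py imports os and contextlib.nullcontext above the hunks shown here. A minimal, self-contained sketch of that pattern (the function name and the sequence are invented for illustration):

    import os
    from contextlib import nullcontext

    def write_sample(y, k, out_path="samples.txt", fasta=True):
        # same derivation as in the diff: samples.txt -> samples.fasta
        fasta_out_path = os.path.splitext(out_path)[0] + ".fasta" if fasta else None
        with open(out_path, 'a') as f:
            # nullcontext() keeps the `with` statement valid when no FASTA file is opened
            with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
                f.write(y)
                if fasta:
                    # one FASTA record per sample: ">sample_<k>" header, then the sequence
                    fasta_f.write(f">sample_{k}\n{y}\n")

    write_sample("MKTAYIAKQR", 0)  # appends ">sample_0" plus the sequence to samples.fasta

One caveat visible in the diff itself: the beam_search/greedy_search branch builds f">sample_{k}..." outside any loop, where k is never assigned, so as written that branch appears to raise NameError when fasta is enabled.

The --top_k help string ("retain only the top_k most likely tokens, clamp others to have 0 probability") describes a standard top-k filter; model.generate itself is not part of this diff, so the following is an assumed sketch of that step, not the repository's implementation:

    import torch

    def top_k_filter(logits, k):
        # keep the k largest logits; clamp the rest to -inf so softmax assigns them probability 0
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('inf')
        return out

    logits = torch.randn(1, 32000)        # stand-in for the model's final-step logits
    logits = logits / 1.0                 # --temperature scaling (< 1.0 sharper, > 1.0 more random)
    probs = torch.softmax(top_k_filter(logits, k=20), dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # sample the next token id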