HaileyStorm
committed on
Commit
•
1230db0
1
Parent(s):
1401d32
Fixed early-stopping in get_mamba_response based on space/dot tokens (now decodes the strings instead of using hardcoded token ids).
Browse files
chess-gpt-eval/mamba_module.py
CHANGED
@@ -81,6 +81,8 @@ class MambaPlayer:
|
|
81 |
self.vocab_size = vocab_size
|
82 |
self.encode = encode
|
83 |
self.decode = decode
|
|
|
|
|
84 |
self.model = model
|
85 |
self.ctx = ctx
|
86 |
self.device = device
|
@@ -107,8 +109,9 @@ class MambaPlayer:
|
|
107 |
|
108 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
109 |
next_token_id = torch.multinomial(probs, num_samples=1)
|
110 |
-
if
|
111 |
-
|
|
|
112 |
else:
|
113 |
have_non_space = True
|
114 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
|
|
81 |
self.vocab_size = vocab_size
|
82 |
self.encode = encode
|
83 |
self.decode = decode
|
84 |
+
self.space_tok = encode(' ')[0]
|
85 |
+
self.dot_tok = encode('.')[0]
|
86 |
self.model = model
|
87 |
self.ctx = ctx
|
88 |
self.device = device
|
|
|
109 |
|
110 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
111 |
next_token_id = torch.multinomial(probs, num_samples=1)
|
112 |
+
if next_token_id == self.space_tok or next_token_id==self.dot_tok:
|
113 |
+
if have_non_space:
|
114 |
+
break
|
115 |
else:
|
116 |
have_non_space = True
|
117 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|