Trickshotblaster committed
Commit 8de17f5 · 1 Parent(s): a4e819a

inference mode go brr

Files changed (2)
  1. __pycache__/gpt.cpython-310.pyc +0 -0
  2. gpt.py +24 -23
__pycache__/gpt.cpython-310.pyc CHANGED
Binary files a/__pycache__/gpt.cpython-310.pyc and b/__pycache__/gpt.cpython-310.pyc differ
 
gpt.py CHANGED
```diff
@@ -137,26 +137,27 @@ my_GPT.eval()
 eot = enc._special_tokens['<|endoftext|>']
 
 def get_response(in_text, top_k=50, temperature=1):
-    prompt = "USER: " + in_text + "\nASSISTANT: "
-    input_tokens = enc.encode(prompt)
-    output_tokens = enc.encode(prompt)
-    for x in range(block_size):
-        if len(input_tokens) > block_size:
-            input_tokens = input_tokens[1:]
-        context_tensor = torch.tensor(input_tokens).view(1, -1).to(device)
-
-        logits, loss = my_GPT(context_tensor)
-        logits = logits[:, -1, :] / temperature
-        if top_k > 0:
-            # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None]
-            logits[indices_to_remove] = float("-inf")
-        probs = F.softmax(logits, dim=-1)
-        result = torch.multinomial(probs, num_samples=1).item()
-        if result == eot:
-            break
-        input_tokens.append(result)
-        output_tokens.append(result)
-        yield enc.decode(output_tokens)
-
-    yield enc.decode(output_tokens)
+    with torch.inference_mode():
+        prompt = "USER: " + in_text + "\nASSISTANT: "
+        input_tokens = enc.encode(prompt)
+        output_tokens = enc.encode(prompt)
+        for x in range(block_size):
+            if len(input_tokens) > block_size:
+                input_tokens = input_tokens[1:]
+            context_tensor = torch.tensor(input_tokens).view(1, -1).to(device)
+
+            logits, loss = my_GPT(context_tensor)
+            logits = logits[:, -1, :] / temperature
+            if top_k > 0:
+                # Remove all tokens with a probability less than the last token of the top-k
+                indices_to_remove = logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None]
+                logits[indices_to_remove] = float("-inf")
+            probs = F.softmax(logits, dim=-1)
+            result = torch.multinomial(probs, num_samples=1).item()
+            if result == eot:
+                break
+            input_tokens.append(result)
+            output_tokens.append(result)
+            yield enc.decode(output_tokens)
+
+        yield enc.decode(output_tokens)
```
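The change wraps the whole generation loop in `torch.inference_mode()`, a stricter and slightly cheaper variant of `torch.no_grad()`: autograd builds no graph and skips version-counter bookkeeping for tensors created inside the block. A minimal standalone sketch of the behavior (not from this repo):

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(1, 4)

# Inside the block autograd does no bookkeeping at all, so forward
# passes run a bit faster and allocate less than under no_grad().
with torch.inference_mode():
    y = lin(x)

print(y.requires_grad)        # False: no grad graph was recorded
print(torch.is_inference(y))  # True: y is an inference tensor
```

One caveat: tensors created in inference mode cannot later participate in autograd-recorded operations, which is fine here since `get_response` only samples and decodes.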
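The top-k filter inside the loop thresholds logits against the k-th largest value before softmax. A toy worked example of that masking, using `masked_fill` in place of the diff's in-place assignment:

```python
import torch

logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])
top_k = 2

# torch.topk(...)[0][..., -1, None] picks the k-th largest logit per row
# (here 2.0); everything strictly below it becomes -inf, so softmax puts
# zero probability there and sampling is confined to the top-k tokens.
kth = torch.topk(logits, top_k, dim=1)[0][..., -1, None]
masked = logits.masked_fill(logits < kth, float("-inf"))

print(torch.softmax(masked, dim=-1))
# tensor([[0.0000, 0.7311, 0.0000, 0.2689, 0.0000]])
```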
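Since `get_response` is a generator that yields the full decoded text after every sampled token, a caller can stream partial output. A hypothetical driver loop (the import path and prompt are assumptions, not from this commit):

```python
# Assumes gpt.py is importable and the model weights are already loaded.
from gpt import get_response

for partial in get_response("What is attention?", top_k=50, temperature=0.8):
    print(partial)  # each yield is the full decoded conversation so far
```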