Update src/model/CHATGPT.py
Browse files- src/model/CHATGPT.py +6 -0
src/model/CHATGPT.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import json
|
2 |
import os
|
|
|
3 |
from typing import Any, Dict, List, Tuple, Union
|
4 |
|
5 |
import numpy as np
|
@@ -287,6 +288,7 @@ class CHATGPT:
|
|
287 |
Returns:
|
288 |
Generated response and updated state.
|
289 |
"""
|
|
|
290 |
conv_dict["context"].append(options[0])
|
291 |
generated_inputs, generated_response = self.get_conv(conv_dict)
|
292 |
options_letter = list(options[1].keys())
|
@@ -312,6 +314,10 @@ class CHATGPT:
|
|
312 |
# response = (
|
313 |
# options[1].get(choice, {}).get("template", generated_response)
|
314 |
# )
|
|
|
|
|
|
|
|
|
315 |
response = generated_response
|
316 |
|
317 |
# Update the state. Hack: penalize the choice to reduce the
|
|
|
1 |
import json
|
2 |
import os
|
3 |
+
from copy import deepcopy
|
4 |
from typing import Any, Dict, List, Tuple, Union
|
5 |
|
6 |
import numpy as np
|
|
|
288 |
Returns:
|
289 |
Generated response and updated state.
|
290 |
"""
|
291 |
+
initial_conv_dict = deepcopy(conv_dict)
|
292 |
conv_dict["context"].append(options[0])
|
293 |
generated_inputs, generated_response = self.get_conv(conv_dict)
|
294 |
options_letter = list(options[1].keys())
|
|
|
314 |
# response = (
|
315 |
# options[1].get(choice, {}).get("template", generated_response)
|
316 |
# )
|
317 |
+
|
318 |
+
# Generate response with original context otherwise generated
|
319 |
+
# response is the option's letter.
|
320 |
+
_, generated_response = self.get_conv(initial_conv_dict)
|
321 |
response = generated_response
|
322 |
|
323 |
# Update the state. Hack: penalize the choice to reduce the
|