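"""Caching wrapper around gpt-engineer's AI class.

CachingAI stores each LLM exchange in a JSON file (ai_cache.json) keyed by
the serialized message history, so repeated runs with identical inputs are
answered from disk instead of making live API calls.
"""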
import json
import os.path

from pathlib import Path
from typing import List, Optional, Union

from langchain.schema import AIMessage, HumanMessage, SystemMessage

from gpt_engineer.core.ai import AI
from gpt_engineer.core.token_usage import TokenUsageLog

# Type hint for a chat message
Message = Union[AIMessage, HumanMessage, SystemMessage]


class CachingAI(AI):
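    """AI subclass that caches LLM responses on disk.

    Responses are keyed by the serialized message history. On a cache miss
    the request falls through to the inherited backoff_inference call and
    the updated conversation is written back to ai_cache.json.
    """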
    def __init__(self, *args, **kwargs):
        self.temperature = 0.1
        self.azure_endpoint = ""
        self.streaming = False
        try:
            self.model_name = "gpt-4-1106-preview"
            self.llm = self._create_chat_model()
        except Exception:
            # No usable model backend (e.g. missing credentials); responses
            # can then only be served from the cache.
            self.model_name = "cached_response_model"
            self.llm = None
        # Cost accounting always uses the real model's pricing, even when
        # responses are served from the cache.
        self.token_usage_log = TokenUsageLog("gpt-4-1106-preview")
        self.cache_file = Path(__file__).parent / "ai_cache.json"

    def next(
        self,
        messages: List[Message],
        prompt: Optional[str] = None,
        *,
        step_name: str,
    ) -> List[Message]:
        """
        Advances the conversation by sending message history
        to LLM and updating with the response.

        Parameters
        ----------
        messages : List[Message]
            The list of messages in the conversation.
        prompt : Optional[str], optional
            The prompt to use, by default None.
        step_name : str
            The name of the step.

        Returns
        -------
        List[Message]
            The updated list of messages in the conversation.
        """
        """
        Advances the conversation by sending message history
        to LLM and updating with the response.
        """
        if prompt:
            messages.append(HumanMessage(content=prompt))

        # Load the cache from disk if it exists; otherwise start empty.
        if os.path.isfile(self.cache_file):
            with open(self.cache_file, "r") as cache_file:
                cache = json.load(cache_file)
        else:
            cache = {}

        messages_key = self.serialize_messages(messages)
        if messages_key not in cache:
            # Cache miss: query the model with retry/backoff, log token usage,
            # then persist the updated conversation under its key.
            callbacks = []
            print("calling backoff inference")
            response = self.backoff_inference(messages, callbacks)
            self.token_usage_log.update_log(
                messages=messages, answer=response.content, step_name=step_name
            )
            print("called backoff inference")
            print("cost in usd:", self.token_usage_log.usage_cost())

            messages.append(response)
            cache[messages_key] = self.serialize_messages(messages)
            with open(self.cache_file, "w") as cache_file:
                json.dump(cache, cache_file)
                cache_file.write("\n")

        return self.deserialize_messages(cache[messages_key])
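

# A minimal usage sketch (illustrative, not part of the class): the first run
# hits the live model and populates ai_cache.json; identical reruns are then
# served from the cache. The prompt text and step name below are arbitrary.
if __name__ == "__main__":
    ai = CachingAI()
    conversation: List[Message] = [
        SystemMessage(content="You are a helpful assistant."),
    ]
    conversation = ai.next(
        conversation, prompt="Say hello.", step_name="usage_sketch"
    )
    print(conversation[-1].content)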