File size: 8,973 Bytes
47c93d7
 
 
d137f7e
 
 
47c93d7
 
d137f7e
 
 
47c93d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d137f7e
 
47c93d7
d137f7e
 
 
 
47c93d7
 
d137f7e
47c93d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d137f7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47c93d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d137f7e
 
 
 
 
 
 
 
 
 
 
 
47c93d7
 
 
d137f7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import openai
import logging
import time
from templates.validation import ValidationTemplate 
from templates.itinerary import ItineraryTemplate
from templates.mapping import MappingTemplate 
from langchain.chat_models  import ChatOpenAI
from langchain.chains import LLMChain, SequentialChain
from datetime import datetime
from datetime import date
from dateutil.relativedelta import relativedelta

# Configure the root logger at import time so INFO-level records are emitted.
# (Per the stdlib docs this is a no-op if the root logger already has handlers.)
# Agent.__init__ separately sets the level of this module's own logger.
logging.basicConfig(level=logging.INFO)

class Agent:
    """Travel-planning agent built from LangChain chains over a ChatOpenAI model.

    Two chains are built at construction time:

    * ``validation_chain`` runs the validation prompt on a short trip request
      and parses the answer with the validation template's parser;
    * ``itinerary_chain`` first asks the model for an itinerary
      ("itinerary_suggestion") and then runs the mapping prompt to extract the
      waypoints as structured output ("mapping_list").
    """

    # Keys callers must supply in ``user_details`` for generate_itinerary. They
    # are also the user-facing inputs of the itinerary chain; the chain
    # additionally takes "format_instructions", which is filled in internally.
    _ITINERARY_INPUT_KEYS = (
        "start_location",
        "end_location",
        "start_date",
        "end_date",
        "attractions",
        "budget",
        "transportation",
        "accommodation",
        "schedule",
    )

    def __init__(
        self,
        openai_api_key,
        model="gpt-3.5-turbo",
        temperature=0,
        debug=True,
    ):
        """Create the chat model, the prompt templates and both chains.

        Args:
            openai_api_key: API key passed through to ``ChatOpenAI``.
            model: Chat model name passed to ``ChatOpenAI``.
            temperature: Sampling temperature passed to ``ChatOpenAI``.
            debug: When truthy, this module's logger is set to DEBUG (otherwise
                INFO) and every chain is built with ``verbose=debug``.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)

        self._openai_key = openai_api_key

        # Initialize ChatOpenAI with the provided OpenAI API key and model details
        self.chat_model = ChatOpenAI(
            model=model, temperature=temperature, openai_api_key=self._openai_key
        )

        # Prompt templates: each exposes ``chat_prompt``; the validation and
        # mapping templates also expose an output ``parser``.
        self.validation_prompt = ValidationTemplate()
        self.itinerary_prompt = ItineraryTemplate()
        self.mapping_prompt = MappingTemplate()

        # Build the validation chain and the two-step itinerary chain.
        self.validation_chain = self._set_up_validation_chain(debug)
        self.itinerary_chain = self._set_up_itinerary_chain(debug)

    @staticmethod
    def _as_dict(parsed):
        """Return a parser result as a plain dict.

        Dicts are returned unchanged; anything else is assumed to expose
        ``.dict()`` (presumably a pydantic model produced by a template's
        parser -- confirm against the templates module).
        """
        if isinstance(parsed, dict):
            return parsed
        return parsed.dict()

    def _set_up_validation_chain(self, debug=True):
        """Build the chain that validates a trip request.

        Inputs: "query", "format_instructions". Output: "validation_output",
        parsed with ``self.validation_prompt.parser``.
        """
        validation_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.validation_prompt.chat_prompt,
            output_parser=self.validation_prompt.parser,
            output_key="validation_output",
            verbose=debug,
        )

        # Wrap the single LLMChain in a SequentialChain so inputs/outputs are
        # declared explicitly by name.
        return SequentialChain(
            chains=[validation_agent],
            input_variables=["query", "format_instructions"],
            output_variables=["validation_output"],
            verbose=debug,
        )

    def _set_up_itinerary_chain(self, debug=True):
        """Build the two-step chain: itinerary text, then extracted waypoints.

        Inputs: every key of ``_ITINERARY_INPUT_KEYS`` plus
        "format_instructions". Outputs: "itinerary_suggestion" (raw model text)
        and "mapping_list" (parsed with ``self.mapping_prompt.parser``).
        """
        # Step 1: get the itinerary as a string.
        itinerary_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.itinerary_prompt.chat_prompt,
            verbose=debug,
            output_key="itinerary_suggestion",
        )

        # Step 2: extract the waypoints as structured output.
        mapping_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.mapping_prompt.chat_prompt,
            output_parser=self.mapping_prompt.parser,
            verbose=debug,
            output_key="mapping_list",
        )

        # The sequential chain runs both steps in order with labelled outputs.
        return SequentialChain(
            chains=[itinerary_agent, mapping_agent],
            input_variables=list(self._ITINERARY_INPUT_KEYS) + ["format_instructions"],
            output_variables=["itinerary_suggestion", "mapping_list"],
            verbose=debug,
        )

    def validate_travel(self, query):
        """Run the validation chain on a natural-language trip request.

        Args:
            query: Short description of the trip, e.g. the phrase built by
                generate_itinerary ("<duration> trip from <start> to <end>").

        Returns:
            The parsed validation output as a dict. generate_itinerary reads
            its "plan_is_valid" and "updated_request" keys.
        """
        self.logger.info("Validating query: %s", query)
        t1 = time.time()

        self.logger.info(
            "Calling validation (model is %s) on user input",
            self.chat_model.model_name,
        )

        validation_result = self.validation_chain.run(
            {
                "query": query,
                "format_instructions": self.validation_prompt.parser.get_format_instructions(),
            }
        )

        self.logger.info("Datatype of validation_result: %s", type(validation_result))

        validation_dict = self._as_dict(validation_result)

        self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
        self.logger.info("Content of validation_dict: %s", validation_dict)

        t2 = time.time()
        self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)

        return validation_dict

    def calculate_duration(self, start_date, end_date):
        """Describe the length of a trip in words, counting both end dates.

        Args:
            start_date: First day of the trip. Must be a ``datetime.date``
                (``datetime`` instances also pass, since they subclass ``date``).
            end_date: Last day of the trip; must not be before ``start_date``.

        Returns:
            A comma-separated string such as "1 month, 2 weeks, 3 days".
            Units whose count is zero are omitted.

        Raises:
            ValueError: if either argument is not a ``date``, or if ``end_date``
                is earlier than ``start_date``.
        """
        if not isinstance(start_date, date) or not isinstance(end_date, date):
            raise ValueError("start_date and end_date must be datetime.date objects")

        if end_date < start_date:
            raise ValueError("End date must be after or equal to start date")

        delta = relativedelta(end_date, start_date)

        # +1 makes the count inclusive, so a same-day trip lasts "1 day". The
        # remaining days (beyond whole years/months) are split into weeks + days.
        weeks, days = divmod(delta.days + 1, 7)

        def _unit(count, name):
            return f"{count} {name}{'s' if count > 1 else ''}"

        parts = [
            _unit(count, name)
            for count, name in (
                (delta.years, "year"),
                (delta.months, "month"),
                (weeks, "week"),
                (days, "day"),
            )
            if count > 0
        ]
        return ", ".join(parts)

    def generate_itinerary(self, user_details):
        """Validate a trip request and, if it is valid, generate an itinerary.

        Args:
            user_details: Mapping that must contain every key listed in
                ``_ITINERARY_INPUT_KEYS``. "start_date"/"end_date" must be
                ``datetime.date`` objects (see calculate_duration). The mapping
                itself is not modified.

        Returns:
            ``None`` if a required key is missing or the dates are invalid.
            Otherwise a dict with keys "itinerary_suggestion",
            "list_of_places" and "validation_dict". When validation reports
            "plan_is_valid" == "no" (case-insensitive), the first two are
            ``None`` and the itinerary chain is not called.
        """
        self.logger.info("Generating itinerary for user details: %s", user_details)

        missing = [key for key in self._ITINERARY_INPUT_KEYS if key not in user_details]
        if missing:
            for key in missing:
                self.logger.error("Missing '%s' in user details.", key)
            return None

        try:
            trip_duration = self.calculate_duration(
                user_details["start_date"], user_details["end_date"]
            )
        except ValueError as e:
            self.logger.error(e)
            return None

        query_phrase = "{} trip from {} to {}".format(
            trip_duration, user_details["start_location"], user_details["end_location"]
        )

        t1 = time.time()

        self.logger.info("Calling validation chain to validate user query")
        validation_dict = self.validate_travel(query_phrase)
        is_plan_valid = validation_dict["plan_is_valid"]

        if is_plan_valid.lower() == "no":
            self.logger.warning("User request was not valid!")
            print("\n######\n Travel plan is not valid \n######\n")
            # Presumably the model's suggested rewrite of the request.
            print(validation_dict["updated_request"])
            return {
                "itinerary_suggestion": None,
                "list_of_places": None,
                "validation_dict": validation_dict,
            }

        self.logger.info("User query is valid. Calling itinerary chain on user details")

        # Copy so the caller's mapping does not gain "format_instructions".
        itinerary_details = user_details.copy()
        itinerary_details["format_instructions"] = (
            self.mapping_prompt.parser.get_format_instructions()
        )

        itinerary_result = self.itinerary_chain(itinerary_details)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = self._as_dict(itinerary_result["mapping_list"])

        self.logger.info("Datatype of list_of_places: %s", type(list_of_places))
        self.logger.info("Content of list_of_places: %s", list_of_places)

        t2 = time.time()
        self.logger.debug("Time to generate itinerary: %.2f seconds", t2 - t1)

        return {
            "itinerary_suggestion": itinerary_suggestion,
            "list_of_places": list_of_places,
            "validation_dict": validation_dict,
        }