import logging
import time
from datetime import date
from datetime import datetime

import openai
from dateutil.relativedelta import relativedelta
from langchain.chains import LLMChain, SequentialChain
from langchain.chat_models import ChatOpenAI

from templates.itinerary import ItineraryTemplate
from templates.mapping import MappingTemplate
from templates.validation import ValidationTemplate

logging.basicConfig(level=logging.INFO)


class Agent:
    """Travel-planning agent that validates a request and builds an itinerary.

    Two LangChain ``SequentialChain`` objects are created at construction:

    * ``validation_chain`` -- one ``LLMChain`` whose parsed output is exposed
      under the key ``"validation_output"``.
    * ``itinerary_chain`` -- two ``LLMChain`` steps run in sequence, exposing
      ``"itinerary_suggestion"`` and ``"mapping_list"``.
    """

    # Keys that ``generate_itinerary`` requires in ``user_details``; the same
    # names are the itinerary chain's input variables.
    _REQUIRED_DETAIL_KEYS = (
        "start_location",
        "end_location",
        "start_date",
        "end_date",
        "attractions",
        "budget",
        "transportation",
        "accommodation",
        "schedule",
    )

    def __init__(
        self,
        openai_api_key,
        model="gpt-3.5-turbo",
        temperature=0,
        debug=True,
    ):
        """Create the chat model, the prompt templates and both chains.

        Args:
            openai_api_key: API key forwarded to ``ChatOpenAI``.
            model: Model name forwarded to ``ChatOpenAI``.
            temperature: Sampling temperature forwarded to ``ChatOpenAI``.
            debug: When true, the logger is set to DEBUG and the chains are
                built with ``verbose=True``; otherwise the logger is set to
                INFO and the chains are not verbose.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)

        self._openai_key = openai_api_key

        # Chat model shared by every chain of this agent.
        self.chat_model = ChatOpenAI(
            model=model,
            temperature=temperature,
            openai_api_key=self._openai_key,
        )

        # Prompt templates (each exposes ``chat_prompt``; the validation and
        # mapping templates also expose a ``parser``).
        self.validation_prompt = ValidationTemplate()
        self.itinerary_prompt = ItineraryTemplate()
        self.mapping_prompt = MappingTemplate()

        self.validation_chain = self._set_up_validation_chain(debug)
        self.itinerary_chain = self._set_up_itinerary_chain(debug)

    def _set_up_validation_chain(self, debug=True):
        """Build the sequential chain that validates a travel query.

        Args:
            debug: Passed as ``verbose`` to both the inner ``LLMChain`` and
                the wrapping ``SequentialChain``.

        Returns:
            A ``SequentialChain`` taking ``"query"`` and
            ``"format_instructions"`` and producing ``"validation_output"``.
        """
        validation_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.validation_prompt.chat_prompt,
            output_parser=self.validation_prompt.parser,
            output_key="validation_output",
            verbose=debug,
        )

        return SequentialChain(
            chains=[validation_agent],
            input_variables=["query", "format_instructions"],
            output_variables=["validation_output"],
            verbose=debug,
        )

    def _set_up_itinerary_chain(self, debug=True):
        """Build the sequential chain that writes and then maps an itinerary.

        The first step produces the itinerary text under
        ``"itinerary_suggestion"``; the second step runs the mapping prompt
        and its parser, producing ``"mapping_list"``.

        Args:
            debug: Passed as ``verbose`` to every chain that is created.

        Returns:
            A ``SequentialChain`` exposing ``"itinerary_suggestion"`` and
            ``"mapping_list"``.
        """
        # Step 1: itinerary as free text.
        itinerary_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.itinerary_prompt.chat_prompt,
            verbose=debug,
            output_key="itinerary_suggestion",
        )

        # Step 2: mapping prompt, parsed by the mapping template's parser.
        mapping_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.mapping_prompt.chat_prompt,
            output_parser=self.mapping_prompt.parser,
            verbose=debug,
            output_key="mapping_list",
        )

        # Running both in one SequentialChain yields labelled outputs for
        # each step from a single call.
        return SequentialChain(
            chains=[itinerary_agent, mapping_agent],
            input_variables=[*self._REQUIRED_DETAIL_KEYS, "format_instructions"],
            output_variables=["itinerary_suggestion", "mapping_list"],
            verbose=debug,
        )

    def validate_travel(self, query):
        """Run the validation chain on ``query`` and return the result as a dict.

        Args:
            query: Text describing the requested trip.

        Returns:
            The validation output as a dictionary. A dict result is returned
            unchanged; any other result is converted with its ``.dict()``
            method (presumably a pydantic model produced by the validation
            parser -- confirm against ``ValidationTemplate``).
        """
        self.logger.info("Validating query: %s", query)
        started = time.perf_counter()

        self.logger.info(
            "Calling validation (model is %s) on user input",
            self.chat_model.model_name,
        )

        validation_result = self.validation_chain.run(
            {
                "query": query,
                "format_instructions": self.validation_prompt.parser.get_format_instructions(),
            }
        )
        self.logger.info("Datatype of validation_result: %s", type(validation_result))

        if isinstance(validation_result, dict):
            validation_dict = validation_result
        else:
            validation_dict = validation_result.dict()

        self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
        self.logger.info("Content of validation_dict: %s", validation_dict)

        self.logger.debug(
            "Time to validate request: %.2f seconds",
            time.perf_counter() - started,
        )
        return validation_dict

    def calculate_duration(self, start_date, end_date):
        """Describe the inclusive span between two dates in words.

        The end day is counted as part of the trip, so identical start and
        end dates give ``"1 day"``. The day component is further split into
        weeks and days.

        Args:
            start_date: First day of the trip (``datetime.date``).
            end_date: Last day of the trip (``datetime.date``).

        Returns:
            A comma-separated string such as ``"1 month, 2 weeks, 3 days"``;
            components equal to zero are omitted.

        Raises:
            ValueError: If either argument is not a ``datetime.date`` or if
                ``end_date`` is earlier than ``start_date``.
        """
        if not isinstance(start_date, date) or not isinstance(end_date, date):
            raise ValueError("start_date and end_date must be datetime.date objects")
        if end_date < start_date:
            raise ValueError("End date must be after or equal to start date")

        delta = relativedelta(end_date, start_date)
        # +1 makes the range inclusive of the end date.
        weeks, days = divmod(delta.days + 1, 7)

        components = (
            ("year", delta.years),
            ("month", delta.months),
            ("week", weeks),
            ("day", days),
        )
        return ", ".join(
            f"{count} {unit}{'s' if count > 1 else ''}"
            for unit, count in components
            if count > 0
        )

    def generate_itinerary(self, user_details):
        """Validate the trip described by ``user_details`` and build an itinerary.

        Args:
            user_details: Mapping that must contain every key listed in
                ``_REQUIRED_DETAIL_KEYS``. ``start_date`` and ``end_date``
                must be ``datetime.date`` objects.

        Returns:
            ``None`` when a required key is missing or the dates are
            rejected by ``calculate_duration``. Otherwise a dict with the
            keys ``"itinerary_suggestion"``, ``"list_of_places"`` and
            ``"validation_dict"``; the first two are ``None`` when the
            validation step answers "no".
        """
        self.logger.info("Generating itinerary for user details: %s", user_details)

        # Report every missing key, not only the first one.
        missing_keys = [
            key for key in self._REQUIRED_DETAIL_KEYS if key not in user_details
        ]
        if missing_keys:
            for key in missing_keys:
                self.logger.error("Missing '%s' in user details.", key)
            return None

        try:
            trip_duration = self.calculate_duration(
                user_details["start_date"], user_details["end_date"]
            )
            query_phrase = "{} trip from {} to {}".format(
                trip_duration,
                user_details["start_location"],
                user_details["end_location"],
            )
        except KeyError as e:
            self.logger.error("Missing key in user details: %s", e)
            return None
        except ValueError as e:
            self.logger.error(e)
            return None

        started = time.perf_counter()

        self.logger.info("Calling validation chain to validate user query")
        validation_dict = self.validate_travel(query_phrase)

        # str() guards against a parser that yields a non-string value.
        plan_answer = str(validation_dict["plan_is_valid"]).strip().lower()

        if plan_answer == "no":
            self.logger.warning("User request was not valid!")
            print("\n######\n Travel plan is not valid \n######\n")
            # The original referenced an undefined name here (NameError).
            print(validation_dict["updated_request"])
            return {
                "itinerary_suggestion": None,
                "list_of_places": None,
                "validation_dict": validation_dict,
            }

        self.logger.info("User query is valid. Calling itinerary chain on user details")

        # Copy so the caller's mapping is not mutated.
        itinerary_details = user_details.copy()
        itinerary_details["format_instructions"] = (
            self.mapping_prompt.parser.get_format_instructions()
        )

        itinerary_result = self.itinerary_chain(itinerary_details)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = itinerary_result["mapping_list"].dict()

        self.logger.info("Datatype of list_of_places: %s", type(list_of_places))
        self.logger.info("Content of list_of_places: %s", list_of_places)

        self.logger.debug(
            "Time to generate itinerary: %.2f seconds",
            time.perf_counter() - started,
        )

        return {
            "itinerary_suggestion": itinerary_suggestion,
            "list_of_places": list_of_places,
            "validation_dict": validation_dict,
        }