# Author: Akj2023
# Commit d137f7e: Improve UI | Add chains for itinerary with map JSON
import openai
import logging
import time
from templates.validation import ValidationTemplate
from templates.itinerary import ItineraryTemplate
from templates.mapping import MappingTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain, SequentialChain
from datetime import datetime
from datetime import date
from dateutil.relativedelta import relativedelta
# Configure the root logger at import time so INFO-level messages are emitted.
# NOTE(review): this runs as a side effect of importing the module and affects
# the whole process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
class Agent:
    """Travel-planning agent built on LangChain chains.

    Wraps two chains around a single ``ChatOpenAI`` model:

    * a validation chain that checks whether a free-text trip request is
      plausible (see ``validate_travel``), and
    * an itinerary chain that writes an itinerary and then extracts map
      waypoints from it (see ``generate_itinerary``).
    """

    # Trip fields the itinerary prompt consumes. Shared by the chain's
    # ``input_variables`` and by the up-front check in ``generate_itinerary``
    # so the two lists cannot drift apart.
    _ITINERARY_INPUT_KEYS = (
        "start_location",
        "end_location",
        "start_date",
        "end_date",
        "attractions",
        "budget",
        "transportation",
        "accommodation",
        "schedule",
    )

    def __init__(
        self,
        openai_api_key,
        model="gpt-3.5-turbo",
        temperature=0,
        debug=True,
    ):
        """Create the chat model, prompt templates and both chains.

        Args:
            openai_api_key: API key passed to ``ChatOpenAI``.
            model: OpenAI chat model name.
            temperature: Sampling temperature for the model.
            debug: When true, log at DEBUG level and make every chain verbose;
                otherwise log at INFO level.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)

        self._openai_key = openai_api_key
        self.chat_model = ChatOpenAI(
            model=model, temperature=temperature, openai_api_key=self._openai_key
        )

        # Prompt templates; each exposes ``chat_prompt`` and (for validation
        # and mapping) a ``parser`` used below.
        self.validation_prompt = ValidationTemplate()
        self.itinerary_prompt = ItineraryTemplate()
        self.mapping_prompt = MappingTemplate()

        self.validation_chain = self._set_up_validation_chain(debug)
        self.itinerary_chain = self._set_up_itinerary_chain(debug)

    @staticmethod
    def _as_dict(parsed):
        """Return a parser's output as a plain dict.

        Output parsers may return either a dict or a model object with a
        ``.dict()`` method (presumably a pydantic model — confirm against the
        templates); both are normalized here.
        """
        if isinstance(parsed, dict):
            return parsed
        return parsed.dict()

    def _set_up_validation_chain(self, debug=True):
        """Build the single-step chain that validates a trip query.

        Inputs: ``query`` and ``format_instructions``.
        Output: ``validation_output``, parsed by the validation template's
        parser.
        """
        validation_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.validation_prompt.chat_prompt,
            output_parser=self.validation_prompt.parser,
            output_key="validation_output",
            verbose=debug,
        )
        return SequentialChain(
            chains=[validation_agent],
            input_variables=["query", "format_instructions"],
            output_variables=["validation_output"],
            verbose=debug,
        )

    def _set_up_itinerary_chain(self, debug=True):
        """Build the two-step itinerary chain.

        Step 1 writes the itinerary as free text (``itinerary_suggestion``).
        Step 2 runs the mapping prompt and parses its reply into
        ``mapping_list``. Presumably the mapping prompt reads
        ``itinerary_suggestion`` from step 1 — confirm in MappingTemplate.
        """
        itinerary_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.itinerary_prompt.chat_prompt,
            verbose=debug,
            output_key="itinerary_suggestion",
        )
        mapping_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.mapping_prompt.chat_prompt,
            output_parser=self.mapping_prompt.parser,
            verbose=debug,
            output_key="mapping_list",
        )
        # SequentialChain runs the agents in order and exposes both labelled
        # outputs to the caller.
        return SequentialChain(
            chains=[itinerary_agent, mapping_agent],
            input_variables=[*self._ITINERARY_INPUT_KEYS, "format_instructions"],
            output_variables=["itinerary_suggestion", "mapping_list"],
            verbose=debug,
        )

    def validate_travel(self, query):
        """Ask the model whether a free-text travel request is reasonable.

        Args:
            query: Natural-language trip description, such as the phrase
                built by ``generate_itinerary``.

        Returns:
            The parsed validation output as a dict. ``generate_itinerary``
            reads its ``plan_is_valid`` and ``updated_request`` keys.
        """
        self.logger.info("Validating query: %s", query)
        t1 = time.perf_counter()
        self.logger.info(
            "Calling validation (model is %s) on user input",
            self.chat_model.model_name,
        )
        validation_result = self.validation_chain.run(
            {
                "query": query,
                "format_instructions": self.validation_prompt.parser.get_format_instructions(),
            }
        )
        self.logger.info("Datatype of validation_result: %s", type(validation_result))

        validation_dict = self._as_dict(validation_result)
        self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
        self.logger.info("Content of validation_dict: %s", validation_dict)
        self.logger.debug(
            "Time to validate request: %.2f seconds", time.perf_counter() - t1
        )
        return validation_dict

    def calculate_duration(self, start_date, end_date):
        """Describe the inclusive length of a trip in human-readable units.

        Both endpoints count as travel days, so a trip that starts and ends on
        the same day is "1 day". Whole years and months come from
        ``relativedelta``; the remaining days are split into weeks and days.

        Args:
            start_date: First day of the trip (``datetime.date``).
            end_date: Last day of the trip (``datetime.date``), not before
                ``start_date``.

        Returns:
            A comma-separated description such as ``"1 month, 1 week, 2 days"``.

        Raises:
            ValueError: If either argument is not a date, or ``end_date`` is
                before ``start_date``. (ValueError rather than TypeError,
                because ``generate_itinerary`` catches ValueError.)
        """
        if not isinstance(start_date, date) or not isinstance(end_date, date):
            raise ValueError("start_date and end_date must be datetime.date objects")
        if end_date < start_date:
            raise ValueError("End date must be after or equal to start date")

        delta = relativedelta(end_date, start_date)
        # +1 makes the count inclusive of both the start and the end day.
        weeks, days = divmod(delta.days + 1, 7)

        parts = []
        for amount, unit in (
            (delta.years, "year"),
            (delta.months, "month"),
            (weeks, "week"),
            (days, "day"),
        ):
            if amount > 0:
                parts.append(f"{amount} {unit}{'s' if amount > 1 else ''}")
        return ", ".join(parts)

    def generate_itinerary(self, user_details):
        """Validate a trip request and, if it is valid, generate an itinerary.

        Args:
            user_details: Mapping containing every key in
                ``_ITINERARY_INPUT_KEYS``. ``start_date`` and ``end_date``
                must be ``datetime.date`` objects (see ``calculate_duration``).

        Returns:
            ``None`` if a required key is missing or the dates are invalid.
            Otherwise a dict with keys ``itinerary_suggestion`` (model text),
            ``list_of_places`` (parsed mapping output as a dict) and
            ``validation_dict``. When validation rejects the plan, the first
            two values are ``None``.
        """
        self.logger.info("Generating itinerary for user details: %s", user_details)

        missing = [key for key in self._ITINERARY_INPUT_KEYS if key not in user_details]
        if missing:
            self.logger.error("Missing keys in user details: %s", ", ".join(missing))
            return None

        try:
            trip_duration = self.calculate_duration(
                user_details["start_date"], user_details["end_date"]
            )
        except ValueError as e:
            self.logger.error(e)
            return None
        query_phrase = "{} trip from {} to {}".format(
            trip_duration, user_details["start_location"], user_details["end_location"]
        )

        t1 = time.perf_counter()
        self.logger.info("Calling validation chain on user query")
        validation_dict = self.validate_travel(query_phrase)

        if str(validation_dict["plan_is_valid"]).strip().lower() == "no":
            self.logger.warning("User request was not valid!")
            print("\n######\n Travel plan is not valid \n######\n")
            # The model's suggested revision of the request, if it gave one.
            print(validation_dict.get("updated_request"))
            return {
                "itinerary_suggestion": None,
                "list_of_places": None,
                "validation_dict": validation_dict,
            }

        self.logger.info("User query is valid. Calling itinerary chain on user details")
        itinerary_details = user_details.copy()
        itinerary_details["format_instructions"] = self.mapping_prompt.parser.get_format_instructions()
        itinerary_result = self.itinerary_chain(itinerary_details)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = self._as_dict(itinerary_result["mapping_list"])
        self.logger.info("Datatype of list_of_places: %s", type(list_of_places))
        self.logger.info("Content of list_of_places: %s", list_of_places)
        self.logger.debug(
            "Time to generate itinerary: %.2f seconds", time.perf_counter() - t1
        )

        return {
            "itinerary_suggestion": itinerary_suggestion,
            "list_of_places": list_of_places,
            "validation_dict": validation_dict,
        }