# Spaces: Runtime error
import openai | |
import logging | |
import time | |
from templates.validation import ValidationTemplate | |
from templates.itinerary import ItineraryTemplate | |
from templates.mapping import MappingTemplate | |
from langchain.chat_models import ChatOpenAI | |
from langchain.chains import LLMChain, SequentialChain | |
from datetime import datetime | |
from datetime import date | |
from dateutil.relativedelta import relativedelta | |
# Configure the root logger when this module is imported so INFO-level
# records are printed without extra setup by the caller. (basicConfig is a
# no-op if the root logger already has handlers.)
logging.basicConfig(level="INFO")
class Agent:
    """Travel-planning agent built on LangChain chat chains.

    A single ``ChatOpenAI`` model backs two chains:

    * a validation chain that judges a free-text trip request
      (see ``validate_travel``), and
    * an itinerary chain that drafts an itinerary and then extracts its
      waypoints with the mapping template's parser
      (see ``generate_itinerary``).
    """

    # Keys ``generate_itinerary`` requires in ``user_details``. Together with
    # ``format_instructions`` they are the itinerary chain's input variables.
    _ITINERARY_INPUT_KEYS = (
        "start_location",
        "end_location",
        "start_date",
        "end_date",
        "attractions",
        "budget",
        "transportation",
        "accommodation",
        "schedule",
    )

    def __init__(
        self,
        openai_api_key,
        model="gpt-3.5-turbo",
        temperature=0,
        debug=True,
    ):
        """Create the chat model, the prompt templates and both chains.

        Args:
            openai_api_key: API key forwarded to ``ChatOpenAI``.
            model: Chat model name forwarded to ``ChatOpenAI``.
            temperature: Sampling temperature forwarded to ``ChatOpenAI``.
            debug: If true, the logger is set to DEBUG and the chains are
                built with ``verbose=True``. Otherwise the logger is set to
                INFO and the chains are quiet.
        """
        self.logger = logging.getLogger(__name__)
        # NOTE: getLogger(__name__) returns the shared module logger, so this
        # level applies to every Agent in the process, not only this one.
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)
        self._openai_key = openai_api_key

        self.chat_model = ChatOpenAI(
            model=model,
            temperature=temperature,
            openai_api_key=self._openai_key,
        )

        # Prompt templates; each exposes ``chat_prompt`` and (validation and
        # mapping) a ``parser``.
        self.validation_prompt = ValidationTemplate()
        self.itinerary_prompt = ItineraryTemplate()
        self.mapping_prompt = MappingTemplate()

        self.validation_chain = self._set_up_validation_chain(debug)
        self.itinerary_chain = self._set_up_itinerary_chain(debug)

    def _set_up_validation_chain(self, debug=True):
        """Build the validation chain.

        The chain takes ``query`` and ``format_instructions`` and yields
        ``validation_output``. That output is the result of the validation
        template's parser.
        """
        validation_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.validation_prompt.chat_prompt,
            output_parser=self.validation_prompt.parser,
            output_key="validation_output",
            verbose=debug,
        )
        return SequentialChain(
            chains=[validation_agent],
            input_variables=["query", "format_instructions"],
            output_variables=["validation_output"],
            verbose=debug,
        )

    def _set_up_itinerary_chain(self, debug=True):
        """Build the two-step itinerary chain.

        Step 1 produces ``itinerary_suggestion`` as raw text. Step 2 feeds the
        mapping prompt and parses the reply into ``mapping_list``.
        """
        itinerary_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.itinerary_prompt.chat_prompt,
            verbose=debug,
            output_key="itinerary_suggestion",
        )
        mapping_agent = LLMChain(
            llm=self.chat_model,
            prompt=self.mapping_prompt.chat_prompt,
            output_parser=self.mapping_prompt.parser,
            verbose=debug,
            output_key="mapping_list",
        )
        # Run the itinerary writer and then the waypoint extractor in
        # sequence, exposing both labelled outputs.
        return SequentialChain(
            chains=[itinerary_agent, mapping_agent],
            input_variables=[*self._ITINERARY_INPUT_KEYS, "format_instructions"],
            output_variables=["itinerary_suggestion", "mapping_list"],
            verbose=debug,
        )

    def validate_travel(self, query):
        """Run the validation chain on a free-text trip request.

        Args:
            query: Natural-language description of the requested trip.

        Returns:
            dict: The validation output as a dict. ``generate_itinerary``
            reads its ``"plan_is_valid"`` and ``"updated_request"`` keys.
        """
        self.logger.info("Validating query: %s", query)
        started = time.perf_counter()
        self.logger.info(
            "Calling validation (model is %s) on user input",
            self.chat_model.model_name,
        )

        validation_result = self.validation_chain.run(
            {
                "query": query,
                "format_instructions": self.validation_prompt.parser.get_format_instructions(),
            }
        )
        self.logger.info("Datatype of validation_result: %s", type(validation_result))

        # The output parser presumably returns a pydantic-style model (anything
        # exposing ``.dict()``); a plain dict is accepted as-is.
        if isinstance(validation_result, dict):
            validation_dict = validation_result
        else:
            validation_dict = validation_result.dict()

        self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
        self.logger.info("Content of validation_dict: %s", validation_dict)
        self.logger.debug(
            "Time to validate request: %.2f seconds", time.perf_counter() - started
        )
        return validation_dict

    def calculate_duration(self, start_date, end_date):
        """Describe the inclusive length of a trip in human-readable units.

        Both endpoints count as travel days, so a same-day trip is "1 day".
        Whole years and months come from ``relativedelta``. The remaining days
        are split into weeks and days, and zero-valued units are omitted.

        Args:
            start_date: First day of the trip (a ``datetime.date``).
            end_date: Last day of the trip (a ``datetime.date``), not before
                ``start_date``.

        Returns:
            str: For example ``"1 year, 2 months, 1 week, 3 days"``.

        Raises:
            ValueError: If either argument is not a ``date``, or if
                ``end_date`` is earlier than ``start_date``.
        """
        if not isinstance(start_date, date) or not isinstance(end_date, date):
            raise ValueError("start_date and end_date must be datetime.date objects")
        if end_date < start_date:
            raise ValueError("End date must be after or equal to start date")

        delta = relativedelta(end_date, start_date)
        # +1 makes the count inclusive of the end date.
        weeks, days = divmod(delta.days + 1, 7)

        def _plural(count, unit):
            # Renders a quantity with a naive English plural, e.g. "2 weeks".
            return f"{count} {unit}{'s' if count > 1 else ''}"

        units = (
            ("year", delta.years),
            ("month", delta.months),
            ("week", weeks),
            ("day", days),
        )
        return ", ".join(_plural(count, unit) for unit, count in units if count > 0)

    def generate_itinerary(self, user_details):
        """Validate a trip request and, if it is accepted, build an itinerary.

        Args:
            user_details: Mapping that must contain every key in
                ``_ITINERARY_INPUT_KEYS``. ``start_date`` and ``end_date``
                must be ``datetime.date`` objects (see ``calculate_duration``).
                The mapping is copied, never modified.

        Returns:
            dict | None: ``None`` if a required key is missing or the dates
            are invalid. Otherwise a dict with ``"itinerary_suggestion"``,
            ``"list_of_places"`` and ``"validation_dict"``. The first two are
            ``None`` when the validation chain answers ``"no"``.
        """
        self.logger.info("Generating itinerary for user details: %s", user_details)

        missing = [key for key in self._ITINERARY_INPUT_KEYS if key not in user_details]
        if missing:
            self.logger.error("Missing %s in user details.", ", ".join(map(repr, missing)))
            return None

        try:
            trip_duration = self.calculate_duration(
                user_details["start_date"], user_details["end_date"]
            )
        except ValueError as err:
            self.logger.error("%s", err)
            return None
        query_phrase = "{} trip from {} to {}".format(
            trip_duration, user_details["start_location"], user_details["end_location"]
        )

        started = time.perf_counter()
        self.logger.info("Calling validation chain to validate user query")
        validation_dict = self.validate_travel(query_phrase)

        # Any surrounding whitespace or capitalisation in the model's yes/no
        # answer is ignored.
        if str(validation_dict["plan_is_valid"]).strip().lower() == "no":
            self.logger.warning("User request was not valid!")
            print("\n######\n Travel plan is not valid \n######\n")
            print(validation_dict.get("updated_request"))
            return {
                "itinerary_suggestion": None,
                "list_of_places": None,
                "validation_dict": validation_dict,
            }

        self.logger.info("User query is valid. Calling itinerary chain on user details")
        itinerary_details = dict(user_details)
        itinerary_details["format_instructions"] = (
            self.mapping_prompt.parser.get_format_instructions()
        )

        itinerary_result = self.itinerary_chain(itinerary_details)
        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = itinerary_result["mapping_list"].dict()

        self.logger.info("Datatype of list_of_places: %s", type(list_of_places))
        self.logger.info("Content of list_of_places: %s", list_of_places)
        self.logger.debug(
            "Time to generate itinerary: %.2f seconds", time.perf_counter() - started
        )

        return {
            "itinerary_suggestion": itinerary_suggestion,
            "list_of_places": list_of_places,
            "validation_dict": validation_dict,
        }