Akj2023 commited on
Commit
47c93d7
Β·
1 Parent(s): 5490f9d

Add validation prompt template, test and agent

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ __pycache__
agents/agent.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import logging
3
+ import time
4
+ from templates.validation import ValidationTemplate # Adjust the import path as necessary
5
+ from langchain.chat_models import ChatOpenAI
6
+ from langchain.chains import LLMChain, SequentialChain
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+
10
+ class Agent:
11
+ def __init__(
12
+ self,
13
+ openai_api_key,
14
+ model="gpt-3.5-turbo",
15
+ temperature=0,
16
+ debug=True,
17
+ ):
18
+ self.logger = logging.getLogger(__name__)
19
+ if debug:
20
+ self.logger.setLevel(logging.DEBUG)
21
+ else:
22
+ self.logger.setLevel(logging.INFO)
23
+
24
+ self._openai_key = openai_api_key
25
+
26
+ # Initialize ChatOpenAI with the provided OpenAI API key and model details
27
+ self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
28
+ # Initialize the ValidationTemplate
29
+ self.validation_prompt = ValidationTemplate()
30
+ # Setup the validation chain using the LLMChain and SequentialChain
31
+ self.validation_chain = self._set_up_validation_chain(debug)
32
+
33
+ def _set_up_validation_chain(self, debug=True):
34
+ # Make validation agent chain using LLMChain
35
+ validation_agent = LLMChain(
36
+ llm=self.chat_model,
37
+ prompt=self.validation_prompt.chat_prompt,
38
+ output_parser=self.validation_prompt.parser,
39
+ output_key="validation_output",
40
+ verbose=debug,
41
+ )
42
+
43
+ # Add validation agent to sequential chain
44
+ overall_chain = SequentialChain(
45
+ chains=[validation_agent],
46
+ input_variables=["query", "format_instructions"],
47
+ output_variables=["validation_output"],
48
+ verbose=debug,
49
+ )
50
+
51
+ return overall_chain
52
+
53
+ def validate_travel(self, query):
54
+ self.logger.info("Validating query: %s", query)
55
+ t1 = time.time()
56
+
57
+ self.logger.info(
58
+ "Calling validation (model is {}) on user input".format(
59
+ self.chat_model.model_name
60
+ )
61
+ )
62
+
63
+
64
+ # Call the validation chain with the query and format instructions
65
+ validation_result = self.validation_chain.run(
66
+ {
67
+ "query": query,
68
+ "format_instructions": self.validation_prompt.parser.get_format_instructions(),
69
+ }
70
+ )
71
+
72
+ # Extract the result from the validation output
73
+ validation_output = validation_result["validation_output"].dict()
74
+ t2 = time.time()
75
+ self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
76
+
77
+ return validation_output
api_integration/openai_integration.py DELETED
File without changes
app.py CHANGED
@@ -31,8 +31,11 @@ if 'end_date' not in st.session_state:
31
  with st.sidebar:
32
  st.header('Enter Your Travel Details')
33
 
34
- # Destination input
35
- destination = st.text_input('Destination', help='Enter the destination for your trip.')
 
 
 
36
 
37
  # Create the start date input widget
38
  start_date = st.date_input(
@@ -93,10 +96,11 @@ with st.sidebar:
93
  # Main page layout
94
  st.header('Your Itinerary')
95
  if submit:
96
- if destination and attractions and start_date and end_date:
97
  # The function to generate the itinerary would go here.
98
  # The following lines are placeholders to show the captured inputs.
99
- st.write('Destination:', destination)
 
100
  st.write('Travel Dates:', st.session_state['start_date'], 'to', st.session_state['end_date'])
101
  st.write('Attractions:', attractions)
102
  st.write('Budget:', budget)
 
31
  with st.sidebar:
32
  st.header('Enter Your Travel Details')
33
 
34
+ # Start location input
35
+ start_location = st.text_input('From', help='Enter the starting location for your trip.')
36
+
37
+ # End location input
38
+ end_location = st.text_input('To', help='Enter the final location for your trip.')
39
 
40
  # Create the start date input widget
41
  start_date = st.date_input(
 
96
  # Main page layout
97
  st.header('Your Itinerary')
98
  if submit:
99
+ if start_location and end_location and attractions and start_date and end_date:
100
  # The function to generate the itinerary would go here.
101
  # The following lines are placeholders to show the captured inputs.
102
+ st.write('From:', start_location)
103
+ st.write('To:', end_location)
104
  st.write('Travel Dates:', st.session_state['start_date'], 'to', st.session_state['end_date'])
105
  st.write('Attractions:', attractions)
106
  st.write('Budget:', budget)
config.py CHANGED
@@ -11,12 +11,18 @@ def load_secrets():
11
  google_maps_key = os.getenv("GOOGLE_MAPS_API_KEY")
12
  google_palm_key = os.getenv("GOOGLE_PALM_API_KEY")
13
 
14
- # Make sure to check if keys were successfully loaded and handle accordingly
15
- if not open_ai_key or not google_maps_key or not google_palm_key:
16
- raise Exception("API keys not found. Ensure .env file is set up correctly for local development or secrets are set for deployment.")
17
-
18
- return {
19
  "OPENAI_API_KEY": open_ai_key,
20
  "GOOGLE_MAPS_API_KEY": google_maps_key,
21
  "GOOGLE_PALM_API_KEY": google_palm_key,
22
  }
 
 
 
 
 
 
 
 
 
 
11
  google_maps_key = os.getenv("GOOGLE_MAPS_API_KEY")
12
  google_palm_key = os.getenv("GOOGLE_PALM_API_KEY")
13
 
14
+ # Collect all keys in a dictionary
15
+ secrets = {
 
 
 
16
  "OPENAI_API_KEY": open_ai_key,
17
  "GOOGLE_MAPS_API_KEY": google_maps_key,
18
  "GOOGLE_PALM_API_KEY": google_palm_key,
19
  }
20
+
21
+ # Check if any of the keys are missing
22
+ missing_keys = [key for key, value in secrets.items() if not value]
23
+ if missing_keys:
24
+ missing_keys_str = ", ".join(missing_keys)
25
+ raise Exception(f"Missing API keys: {missing_keys_str}. Ensure .env file is set up correctly.")
26
+
27
+ return secrets
28
+
requirements.txt CHANGED
@@ -1,15 +1,14 @@
1
  streamlit
2
- requests>=2.26.0
3
- pandas>=1.3.5
4
- folium>=0.14.0
5
- streamlit-folium>=0.4.0
6
- python-dotenv>=1.0.0
7
- httpx>=0.21.1
8
  langchain
9
  python-dotenv
10
  openai
11
  folium
12
  google-generativeai
13
- geopandas>=0.13.2
14
  tiktoken
15
  duckduckgo-search
 
1
  streamlit
2
+ requests
3
+ folium
4
+ streamlit-folium
5
+ python-dotenv
6
+ httpx
 
7
  langchain
8
  python-dotenv
9
  openai
10
  folium
11
  google-generativeai
12
+ geopandas
13
  tiktoken
14
  duckduckgo-search
{api_integration β†’ templates}/__init__.py RENAMED
File without changes
templates/validation.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import (
2
+ ChatPromptTemplate,
3
+ SystemMessagePromptTemplate,
4
+ HumanMessagePromptTemplate,
5
+ )
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from pydantic import BaseModel, Field
8
+
9
+ # Validation schema
10
+ class Validation(BaseModel):
11
+ plan_is_valid: str = Field(description="This field is 'yes' if the plan is feasible, 'no' otherwise")
12
+ updated_request: str = Field(description="Your update to the plan")
13
+
14
+ class ValidationTemplate:
15
+ def __init__(self):
16
+ self.system_template = """
17
+ You are a travel agent who helps users make exciting travel plans.
18
+
19
+ The user's request will be denoted by four hashtags. Determine if the user's
20
+ request is reasonable and achievable within the constraints they set.
21
+
22
+ A valid request should contain the following:
23
+ - A start and end location
24
+ - A trip duration that is reasonable given the start and end location
25
+ - Some other details, like the user's interests and/or preferred mode of transport
26
+
27
+ Any request that contains potentially harmful activities is not valid, regardless of what
28
+ other details are provided.
29
+
30
+ If the request is not valid, set
31
+ plan_is_valid = 'no' and use your travel expertise to update the request to make it valid,
32
+ keeping your revised request shorter than 100 words.
33
+
34
+ If the request seems reasonable, then set plan_is_valid = 'yes' and
35
+ don't revise the request.
36
+
37
+ {format_instructions}
38
+ """
39
+
40
+ self.human_template = """
41
+ ####{query}####
42
+ """
43
+
44
+ self.parser = PydanticOutputParser(pydantic_object=Validation)
45
+
46
+ self.system_message_prompt = SystemMessagePromptTemplate.from_template(
47
+ self.system_template,
48
+ partial_variables={
49
+ "format_instructions": self.parser.get_format_instructions()
50
+ },
51
+ )
52
+ self.human_message_prompt = HumanMessagePromptTemplate.from_template(
53
+ self.human_template,
54
+ input_variables=["query"]
55
+ )
56
+
57
+ self.chat_prompt = ChatPromptTemplate.from_messages(
58
+ [self.system_message_prompt, self.human_message_prompt]
59
+ )
api_integration/google_maps_integration.py β†’ tests/__init__.py RENAMED
File without changes
api_integration/google_palm_integration.py β†’ tests/agent/__init__.py RENAMED
File without changes
tests/agent/test_agent.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from config import load_secrets # Update this path to match the actual location
3
+ from agents.agent import Agent # Update this path to match the actual location
4
+
5
+ DEBUG = True
6
+
7
+ class TestAgentMethods(unittest.TestCase):
8
+
9
+ def assert_secrets(self, secrets_dict):
10
+ assert secrets_dict["OPENAI_API_KEY"] is not None
11
+ assert secrets_dict["GOOGLE_MAPS_API_KEY"] is not None
12
+ assert secrets_dict["GOOGLE_PALM_API_KEY"] is not None
13
+
14
+
15
+ def setUp(self):
16
+ self.debug = DEBUG
17
+ secrets = load_secrets()
18
+ self.assert_secrets(secrets)
19
+
20
+ # Assuming you only need the OPENAI_API_KEY for this test
21
+ self.agent = Agent(
22
+ openai_api_key=secrets["OPENAI_API_KEY"],
23
+ debug=self.debug,
24
+ )
25
+
26
+ # @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
27
+ def test_validation_chain(self):
28
+ validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
29
+
30
+ # not a reasonable request
31
+ q1 = "fly to the moon"
32
+ q1_res = validation_chain(
33
+ {
34
+ "query": q1,
35
+ "format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
36
+ }
37
+ )
38
+ q1_out = q1_res["validation_output"].dict()
39
+ self.assertEqual(q1_out["plan_is_valid"], "no")
40
+
41
+ # not a reasonable request
42
+ q2 = "1 day road trip from Chicago to Brazilia"
43
+ q2_res = validation_chain(
44
+ {
45
+ "query": q2,
46
+ "format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
47
+ }
48
+ )
49
+ q2_out = q2_res["validation_output"].dict()
50
+ self.assertEqual(q2_out["plan_is_valid"], "no")
51
+
52
+ # a reasonable request
53
+ q3 = "1 week road trip from Chicago to Mexico city"
54
+ q3_res = validation_chain(
55
+ {
56
+ "query": q3,
57
+ "format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
58
+ }
59
+ )
60
+ q3_out = q3_res["validation_output"].dict()
61
+ self.assertEqual(q3_out["plan_is_valid"], "yes")
62
+
63
+
64
+ if __name__ == "__main__":
65
+ unittest.main()