Spaces:
Runtime error
Runtime error
Add validation prompt template, test and agent
Browse files- .gitignore +2 -0
- agents/agent.py +77 -0
- api_integration/openai_integration.py +0 -0
- app.py +8 -4
- config.py +11 -5
- requirements.txt +6 -7
- {api_integration β templates}/__init__.py +0 -0
- templates/validation.py +59 -0
- api_integration/google_maps_integration.py β tests/__init__.py +0 -0
- api_integration/google_palm_integration.py β tests/agent/__init__.py +0 -0
- tests/agent/test_agent.py +65 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
venv
|
2 |
+
__pycache__
|
agents/agent.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
from templates.validation import ValidationTemplate # Adjust the import path as necessary
|
5 |
+
from langchain.chat_models import ChatOpenAI
|
6 |
+
from langchain.chains import LLMChain, SequentialChain
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
|
10 |
+
class Agent:
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
openai_api_key,
|
14 |
+
model="gpt-3.5-turbo",
|
15 |
+
temperature=0,
|
16 |
+
debug=True,
|
17 |
+
):
|
18 |
+
self.logger = logging.getLogger(__name__)
|
19 |
+
if debug:
|
20 |
+
self.logger.setLevel(logging.DEBUG)
|
21 |
+
else:
|
22 |
+
self.logger.setLevel(logging.INFO)
|
23 |
+
|
24 |
+
self._openai_key = openai_api_key
|
25 |
+
|
26 |
+
# Initialize ChatOpenAI with the provided OpenAI API key and model details
|
27 |
+
self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
|
28 |
+
# Initialize the ValidationTemplate
|
29 |
+
self.validation_prompt = ValidationTemplate()
|
30 |
+
# Setup the validation chain using the LLMChain and SequentialChain
|
31 |
+
self.validation_chain = self._set_up_validation_chain(debug)
|
32 |
+
|
33 |
+
def _set_up_validation_chain(self, debug=True):
|
34 |
+
# Make validation agent chain using LLMChain
|
35 |
+
validation_agent = LLMChain(
|
36 |
+
llm=self.chat_model,
|
37 |
+
prompt=self.validation_prompt.chat_prompt,
|
38 |
+
output_parser=self.validation_prompt.parser,
|
39 |
+
output_key="validation_output",
|
40 |
+
verbose=debug,
|
41 |
+
)
|
42 |
+
|
43 |
+
# Add validation agent to sequential chain
|
44 |
+
overall_chain = SequentialChain(
|
45 |
+
chains=[validation_agent],
|
46 |
+
input_variables=["query", "format_instructions"],
|
47 |
+
output_variables=["validation_output"],
|
48 |
+
verbose=debug,
|
49 |
+
)
|
50 |
+
|
51 |
+
return overall_chain
|
52 |
+
|
53 |
+
def validate_travel(self, query):
|
54 |
+
self.logger.info("Validating query: %s", query)
|
55 |
+
t1 = time.time()
|
56 |
+
|
57 |
+
self.logger.info(
|
58 |
+
"Calling validation (model is {}) on user input".format(
|
59 |
+
self.chat_model.model_name
|
60 |
+
)
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
# Call the validation chain with the query and format instructions
|
65 |
+
validation_result = self.validation_chain.run(
|
66 |
+
{
|
67 |
+
"query": query,
|
68 |
+
"format_instructions": self.validation_prompt.parser.get_format_instructions(),
|
69 |
+
}
|
70 |
+
)
|
71 |
+
|
72 |
+
# Extract the result from the validation output
|
73 |
+
validation_output = validation_result["validation_output"].dict()
|
74 |
+
t2 = time.time()
|
75 |
+
self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
|
76 |
+
|
77 |
+
return validation_output
|
api_integration/openai_integration.py
DELETED
File without changes
|
app.py
CHANGED
@@ -31,8 +31,11 @@ if 'end_date' not in st.session_state:
|
|
31 |
with st.sidebar:
|
32 |
st.header('Enter Your Travel Details')
|
33 |
|
34 |
-
#
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
# Create the start date input widget
|
38 |
start_date = st.date_input(
|
@@ -93,10 +96,11 @@ with st.sidebar:
|
|
93 |
# Main page layout
|
94 |
st.header('Your Itinerary')
|
95 |
if submit:
|
96 |
-
if
|
97 |
# The function to generate the itinerary would go here.
|
98 |
# The following lines are placeholders to show the captured inputs.
|
99 |
-
st.write('
|
|
|
100 |
st.write('Travel Dates:', st.session_state['start_date'], 'to', st.session_state['end_date'])
|
101 |
st.write('Attractions:', attractions)
|
102 |
st.write('Budget:', budget)
|
|
|
31 |
with st.sidebar:
|
32 |
st.header('Enter Your Travel Details')
|
33 |
|
34 |
+
# Start location input
|
35 |
+
start_location = st.text_input('From', help='Enter the starting location for your trip.')
|
36 |
+
|
37 |
+
# End location input
|
38 |
+
end_location = st.text_input('To', help='Enter the final location for your trip.')
|
39 |
|
40 |
# Create the start date input widget
|
41 |
start_date = st.date_input(
|
|
|
96 |
# Main page layout
|
97 |
st.header('Your Itinerary')
|
98 |
if submit:
|
99 |
+
if start_location and end_location and attractions and start_date and end_date:
|
100 |
# The function to generate the itinerary would go here.
|
101 |
# The following lines are placeholders to show the captured inputs.
|
102 |
+
st.write('From:', start_location)
|
103 |
+
st.write('To:', end_location)
|
104 |
st.write('Travel Dates:', st.session_state['start_date'], 'to', st.session_state['end_date'])
|
105 |
st.write('Attractions:', attractions)
|
106 |
st.write('Budget:', budget)
|
config.py
CHANGED
@@ -11,12 +11,18 @@ def load_secrets():
|
|
11 |
google_maps_key = os.getenv("GOOGLE_MAPS_API_KEY")
|
12 |
google_palm_key = os.getenv("GOOGLE_PALM_API_KEY")
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
raise Exception("API keys not found. Ensure .env file is set up correctly for local development or secrets are set for deployment.")
|
17 |
-
|
18 |
-
return {
|
19 |
"OPENAI_API_KEY": open_ai_key,
|
20 |
"GOOGLE_MAPS_API_KEY": google_maps_key,
|
21 |
"GOOGLE_PALM_API_KEY": google_palm_key,
|
22 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
google_maps_key = os.getenv("GOOGLE_MAPS_API_KEY")
|
12 |
google_palm_key = os.getenv("GOOGLE_PALM_API_KEY")
|
13 |
|
14 |
+
# Collect all keys in a dictionary
|
15 |
+
secrets = {
|
|
|
|
|
|
|
16 |
"OPENAI_API_KEY": open_ai_key,
|
17 |
"GOOGLE_MAPS_API_KEY": google_maps_key,
|
18 |
"GOOGLE_PALM_API_KEY": google_palm_key,
|
19 |
}
|
20 |
+
|
21 |
+
# Check if any of the keys are missing
|
22 |
+
missing_keys = [key for key, value in secrets.items() if not value]
|
23 |
+
if missing_keys:
|
24 |
+
missing_keys_str = ", ".join(missing_keys)
|
25 |
+
raise Exception(f"Missing API keys: {missing_keys_str}. Ensure .env file is set up correctly.")
|
26 |
+
|
27 |
+
return secrets
|
28 |
+
|
requirements.txt
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
streamlit
|
2 |
-
requests
|
3 |
-
|
4 |
-
folium
|
5 |
-
|
6 |
-
|
7 |
-
httpx>=0.21.1
|
8 |
langchain
|
9 |
python-dotenv
|
10 |
openai
|
11 |
folium
|
12 |
google-generativeai
|
13 |
-
geopandas
|
14 |
tiktoken
|
15 |
duckduckgo-search
|
|
|
1 |
streamlit
|
2 |
+
requests
|
3 |
+
folium
|
4 |
+
streamlit-folium
|
5 |
+
python-dotenv
|
6 |
+
httpx
|
|
|
7 |
langchain
|
8 |
python-dotenv
|
9 |
openai
|
10 |
folium
|
11 |
google-generativeai
|
12 |
+
geopandas
|
13 |
tiktoken
|
14 |
duckduckgo-search
|
{api_integration β templates}/__init__.py
RENAMED
File without changes
|
templates/validation.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import (
|
2 |
+
ChatPromptTemplate,
|
3 |
+
SystemMessagePromptTemplate,
|
4 |
+
HumanMessagePromptTemplate,
|
5 |
+
)
|
6 |
+
from langchain.output_parsers import PydanticOutputParser
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
|
9 |
+
# Validation schema
|
10 |
+
class Validation(BaseModel):
|
11 |
+
plan_is_valid: str = Field(description="This field is 'yes' if the plan is feasible, 'no' otherwise")
|
12 |
+
updated_request: str = Field(description="Your update to the plan")
|
13 |
+
|
14 |
+
class ValidationTemplate:
|
15 |
+
def __init__(self):
|
16 |
+
self.system_template = """
|
17 |
+
You are a travel agent who helps users make exciting travel plans.
|
18 |
+
|
19 |
+
The user's request will be denoted by four hashtags. Determine if the user's
|
20 |
+
request is reasonable and achievable within the constraints they set.
|
21 |
+
|
22 |
+
A valid request should contain the following:
|
23 |
+
- A start and end location
|
24 |
+
- A trip duration that is reasonable given the start and end location
|
25 |
+
- Some other details, like the user's interests and/or preferred mode of transport
|
26 |
+
|
27 |
+
Any request that contains potentially harmful activities is not valid, regardless of what
|
28 |
+
other details are provided.
|
29 |
+
|
30 |
+
If the request is not valid, set
|
31 |
+
plan_is_valid = 'no' and use your travel expertise to update the request to make it valid,
|
32 |
+
keeping your revised request shorter than 100 words.
|
33 |
+
|
34 |
+
If the request seems reasonable, then set plan_is_valid = 'yes' and
|
35 |
+
don't revise the request.
|
36 |
+
|
37 |
+
{format_instructions}
|
38 |
+
"""
|
39 |
+
|
40 |
+
self.human_template = """
|
41 |
+
####{query}####
|
42 |
+
"""
|
43 |
+
|
44 |
+
self.parser = PydanticOutputParser(pydantic_object=Validation)
|
45 |
+
|
46 |
+
self.system_message_prompt = SystemMessagePromptTemplate.from_template(
|
47 |
+
self.system_template,
|
48 |
+
partial_variables={
|
49 |
+
"format_instructions": self.parser.get_format_instructions()
|
50 |
+
},
|
51 |
+
)
|
52 |
+
self.human_message_prompt = HumanMessagePromptTemplate.from_template(
|
53 |
+
self.human_template,
|
54 |
+
input_variables=["query"]
|
55 |
+
)
|
56 |
+
|
57 |
+
self.chat_prompt = ChatPromptTemplate.from_messages(
|
58 |
+
[self.system_message_prompt, self.human_message_prompt]
|
59 |
+
)
|
api_integration/google_maps_integration.py β tests/__init__.py
RENAMED
File without changes
|
api_integration/google_palm_integration.py β tests/agent/__init__.py
RENAMED
File without changes
|
tests/agent/test_agent.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from config import load_secrets # Update this path to match the actual location
|
3 |
+
from agents.agent import Agent # Update this path to match the actual location
|
4 |
+
|
5 |
+
DEBUG = True
|
6 |
+
|
7 |
+
class TestAgentMethods(unittest.TestCase):
|
8 |
+
|
9 |
+
def assert_secrets(self, secrets_dict):
|
10 |
+
assert secrets_dict["OPENAI_API_KEY"] is not None
|
11 |
+
assert secrets_dict["GOOGLE_MAPS_API_KEY"] is not None
|
12 |
+
assert secrets_dict["GOOGLE_PALM_API_KEY"] is not None
|
13 |
+
|
14 |
+
|
15 |
+
def setUp(self):
|
16 |
+
self.debug = DEBUG
|
17 |
+
secrets = load_secrets()
|
18 |
+
self.assert_secrets(secrets)
|
19 |
+
|
20 |
+
# Assuming you only need the OPENAI_API_KEY for this test
|
21 |
+
self.agent = Agent(
|
22 |
+
openai_api_key=secrets["OPENAI_API_KEY"],
|
23 |
+
debug=self.debug,
|
24 |
+
)
|
25 |
+
|
26 |
+
# @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
|
27 |
+
def test_validation_chain(self):
|
28 |
+
validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
|
29 |
+
|
30 |
+
# not a reasonable request
|
31 |
+
q1 = "fly to the moon"
|
32 |
+
q1_res = validation_chain(
|
33 |
+
{
|
34 |
+
"query": q1,
|
35 |
+
"format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
|
36 |
+
}
|
37 |
+
)
|
38 |
+
q1_out = q1_res["validation_output"].dict()
|
39 |
+
self.assertEqual(q1_out["plan_is_valid"], "no")
|
40 |
+
|
41 |
+
# not a reasonable request
|
42 |
+
q2 = "1 day road trip from Chicago to Brazilia"
|
43 |
+
q2_res = validation_chain(
|
44 |
+
{
|
45 |
+
"query": q2,
|
46 |
+
"format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
|
47 |
+
}
|
48 |
+
)
|
49 |
+
q2_out = q2_res["validation_output"].dict()
|
50 |
+
self.assertEqual(q2_out["plan_is_valid"], "no")
|
51 |
+
|
52 |
+
# a reasonable request
|
53 |
+
q3 = "1 week road trip from Chicago to Mexico city"
|
54 |
+
q3_res = validation_chain(
|
55 |
+
{
|
56 |
+
"query": q3,
|
57 |
+
"format_instructions": self.agent.validation_prompt.parser.get_format_instructions(),
|
58 |
+
}
|
59 |
+
)
|
60 |
+
q3_out = q3_res["validation_output"].dict()
|
61 |
+
self.assertEqual(q3_out["plan_is_valid"], "yes")
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
unittest.main()
|