green-city-finder / src /augmentation /prompt_generation.py
Ashmi Banerjee
updates to the s-fairness calculation and refactoring code duplication
ac20456
raw
history blame
7.21 kB
from information_retrieval import info_retrieval as ir
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Configures the root logger at import time with level DEBUG.
# NOTE(review): per the logging docs, `encoding` is only used when basicConfig() creates a
# FileHandler from `filename` (not passed here), the keyword requires Python 3.9+, and
# basicConfig() does nothing if the root logger already has handlers - confirm this is intended.
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
def generate_prompt(query, context, template=None):
    """
    Build the chat messages (system + user) for the LLM from a query and its context.

    The system message is ``template`` when a truthy template is given; otherwise the built-in
    recommendation prompt is used. The user message embeds the query and the context text.

    Args:
        - query: str; the user's question
        - context: retrieved context, interpolated into the user message with str.format
        - template: str, optional; system prompt that replaces the built-in one

    Returns:
        A list of two dicts with "role" and "content" keys: the system message first,
        the user message second.
    """
    # Built-in system prompt; the line breaks of the prompt text are spelled out as "\n".
    default_system_prompt = (
        "You are an AI recommendation system. Your task is to recommend cities in Europe for travel\n"
        "based on the user's question. You should use the provided contexts to suggest a list of the 3 best cities\n"
        "that are best suited to the user's question, as well as the month of travel. If the user has already provided\n"
        "the month of travel in the question, use the same month; otherwise, provide the ideal month of travel. Each\n"
        "recommendation should also contain an explanation of why it is being recommended, based on the context. Your\n"
        'answer must begin with "I recommend " followed by the city name and why you recommended it. Your answers are\n'
        "correct, high-quality, and written by a domain expert. If the provided context does not contain the answer,\n"
        'simply state, "The provided context does not have the answer." '
    )
    # An empty or missing template falls back to the built-in prompt.
    system_prompt = template or default_system_prompt

    user_template = (
        " Question: {} Which city do you recommend and why?\n"
        "Context: Here are the options: {}\n"
        "Answer:\n"
    )
    user_message = user_template.format(query, context)

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
# Listing types that are presented to the LLM as attractions / places to eat or drink.
# Every other listing type is presented as a hotel (see _group_listings).
_ATTRACTION_TYPES = ('see', 'do', 'go', 'view')
_FOOD_TYPES = ('eat', 'drink')


def _sustainability_sentence(city, sustainability):
    """Return the sentence describing the city's s-fairness score (or its absence)."""
    if sustainability['month'] == 'No data available':
        return "This city has no sustainability (or s-fairness) score available."
    return (
        f"The sustainability (or s-fairness) score for {city} in {sustainability['month']} "
        f"is {sustainability['s-fairness']}."
    )


def _group_listings(listings):
    """Bucket listings by category label, keeping the order attractions, food, hotels."""
    grouped = {
        "Here are some attractions: ": [],
        "Here are some places to eat/drink: ": [],
        "Here are some hotels: ": [],
    }
    for listing in listings:
        entry = f"{listing['name']} ({listing['description']})"
        if listing['type'] in _ATTRACTION_TYPES:
            grouped["Here are some attractions: "].append(entry)
        elif listing['type'] in _FOOD_TYPES:
            grouped["Here are some places to eat/drink: "].append(entry)
        else:
            # NOTE(review): any type that is not an attraction or food type lands here and is
            # shown as a hotel - confirm that no other listing types can occur.
            grouped["Here are some hotels: "].append(entry)
    return grouped


def format_context(context):
    """
    Format the retrieved context as plain text that is easy for the LLM to understand.

    Each city becomes one "Option N" paragraph: the country, the optional sustainability
    sentence, the city description and one line per non-empty listing category.

    Args:
        - context: dict mapping a city name to its info dict. The info dict must contain
          'country' and 'text'; 'listings' (list of dicts with 'type', 'name', 'description')
          and 'sustainability' (dict with 'month' and 's-fairness') are optional.

    Returns:
        str: the concatenated option paragraphs, each followed by "\\n\\n "; an empty string
        for an empty context.
    """
    paragraphs = []
    for position, (city, info) in enumerate(context.items(), start=1):
        sentences = [f"Option {position}: {city} is a city in {info['country']}."]
        # The sustainability sentence comes before the (long) city description: if it were
        # added at the end it could get truncated because of the context window.
        if "sustainability" in info:
            sentences.append(_sustainability_sentence(city, info['sustainability']))
        sentences.append(f"Here is some information about the city. {info['text']}")

        # Sentences are separated by a single space so that they do not run into each other.
        lines = [" ".join(sentences)]

        # A missing or empty 'listings' entry simply produces no category lines.
        grouped = _group_listings(info.get('listings') or [])
        lines.extend(
            label + ", ".join(entries)
            for label, entries in grouped.items()
            if entries
        )
        paragraphs.append("\n".join(lines) + "\n\n ")
    return "".join(paragraphs)
def augment_prompt(query, context, **params):
    """
    Augment the user query with the retrieved context and return the messages for the LLM.

    The context is first flattened to text with format_context (the LLM cannot understand the
    nested dictionaries) and then combined with the query by generate_prompt. When the
    sustainability option is truthy, the sustainability-aware system prompt is used instead
    of the default one.

    Args:
        - query: str; the user's question
        - context: retrieved context as accepted by format_context
        - params: options, given either as a single ``params=<dict>`` keyword (the form used
          by existing callers) or directly as keyword arguments. Only the 'sustainability'
          entry is read here; when truthy the LLM is instructed to use the s-fairness scores.

    Returns:
        The list of chat messages produced by generate_prompt.
    """
    # TODO: what about the cities without s-fairness scores? i.e. they don't have seasonality data
    prompt_with_sustainability = (
        "You are an AI recommendation system. Your task is to recommend cities in Europe\n"
        "for travel based on the user's question. You should use the provided contexts to suggest the city that is best\n"
        "suited to the user's question. You recommend a list of the top 3 most sustainable cities to the user, as well as\n"
        "the best month of travel. Each recommendation should also contain an explanation of why it is being recommended,\n"
        "on sustainability grounds based on the context. The context contains a sustainability score for each city,\n"
        "also known as the s-fairness score, along with the ideal month of travel. A lower s-fairness score indicates that\n"
        "the city is a better destination for the month provided. A city without a sustainability score should not be\n"
        "considered. You should only consider the s-fairness score while choosing the best city. However, your answer\n"
        "should not contain the numeric score itself or any mention of the sustainability score. Your answer must begin\n"
        'with "I recommend " followed by the city names and why you recommended it. Your answers are correct, high-quality,\n'
        'and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided\n'
        # The quoted fallback sentence is closed here; it was previously left unbalanced.
        'context does not have the answer." '
    )

    formatted_context = format_context(context)

    # Accept both call styles: augment_prompt(q, c, params={...}) and augment_prompt(q, c, **{...}).
    # A missing or empty options mapping means "no sustainability" rather than a KeyError.
    options = params.get("params", params) or {}
    if options.get("sustainability"):
        return generate_prompt(query, formatted_context, prompt_with_sustainability)
    return generate_prompt(query, formatted_context)
def test():
    """
    Smoke run of the prompt pipeline for a fixed winter-holiday query.

    Retrieves context through ir.get_context and builds the prompt twice: first with the
    sustainability option off, then with it on. Only the sustainability-aware prompt is returned.
    """
    query = (
        "Suggest some places to visit during winter. "
        "I like hiking, nature and the mountains and I enjoy skiing in winter. "
    )
    context_params = dict(limit=3, reranking=0, sustainability=0)

    # Pass 1: without sustainability (the result is built but not returned).
    plain_context = ir.get_context(query, **context_params)
    augment_prompt(query=query, context=plain_context, params=context_params)

    # Pass 2: with sustainability.
    context_params['sustainability'] = 1
    sustainable_context = ir.get_context(query, **context_params)
    return augment_prompt(query=query, context=sustainable_context, params=context_params)
# When executed as a script, run the smoke test and print the resulting prompt.
if __name__ == "__main__":
    print(test())