|
import asyncio |
|
import re |
|
import logging |
|
|
|
from schema import Answer, Question |
|
# Module-wide logger. getLogger() is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()
|
import re |
|
from ingest import KURS_URL, DEFAULT_LANGUAGE |
|
from langchain.callbacks import get_openai_callback |
|
|
|
from config import State |
|
|
|
# Regex for candidate course codes: 2-3 word characters, then 3-4 digits, then
# an optional trailing word character (e.g. "ABC1234"). Note that \w also
# matches digits and underscores and the pattern has no word boundaries, so
# matches are only candidates; question_handler filters them through
# state.course_exists().
COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?"
|
|
|
def blocking_chain(chain, request):
    """Invoke ``chain`` synchronously on ``request`` and return its result.

    The chain is called with ``return_only_outputs=True``. Being a plain
    synchronous function, this can be handed to ``asyncio.to_thread``.
    """
    outputs = chain(request, return_only_outputs=True)
    return outputs
|
|
|
async def question_handler(question: Question, state: State) -> Answer:
    """Run the QA chain for a question and build the answer with course links.

    The chain in ``state.chain`` is executed in a worker thread (it is a
    blocking call). Course codes are taken from the chain's ``sources`` output
    when present, otherwise they are parsed out of the answer text itself.
    Only codes accepted by ``state.course_exists`` are turned into URLs.

    Args:
        question: Incoming request; its ``question`` attribute holds the text.
        state: Application state providing ``chain`` and ``course_exists``.

    Returns:
        An ``Answer`` with the cleaned answer text and a list of course URLs
        (empty when no known course was referenced or the chain refused).
    """
    query = question.question
    logger.info(f"Q: {query}")

    # The chain call blocks, so run it off the event loop; the callback
    # accumulates the OpenAI cost of everything executed inside the `with`.
    with get_openai_callback() as cb:
        result = await asyncio.to_thread(blocking_chain, state.chain, {"question": query})
        cost = cb.total_cost
    logger.debug(f"result: {result}")
    # Logged right away so refused questions report their cost as well.
    logger.info(f"Cost of query: ${cost:.2g}")

    answer = result['answer']
    logger.info(f"A: {answer}")

    if answer.startswith("I cannot help"):
        # NOTE(review): this path previously passed a "url" key while the
        # normal path passes "urls"; aligned on "urls" — confirm against
        # schema.Answer.
        return Answer(answer="I'm sorry, " + answer, urls=[])

    sources = result.get('sources')
    logger.info(f"Sources: {sources}")
    if sources:
        candidates = re.findall(COURSE_PATTERN, sources)
    elif "none of the sources" not in answer.lower():
        # No structured sources: try to peel a source listing off the answer.
        answer, candidates = split_sources(answer)
    else:
        # Nothing to extract. Without this default the loop below would
        # iterate over None and raise TypeError.
        candidates = []

    # De-duplicate while keeping first-seen order so the URL order is stable.
    courses = list(dict.fromkeys(
        code.upper() for code in candidates if state.course_exists(code)
    ))
    logger.info(f"unique courses: {courses}")

    urls = [KURS_URL.format(course_code=course, language=DEFAULT_LANGUAGE) for course in courses]
    logger.info(f"urls: {urls}")

    # Removing the source listing can leave a dangling "(" at the end.
    answer = answer.strip().removesuffix("(").strip()

    # Fewer than 3 characters (including empty) is not a usable answer; if
    # links were found, hand them over with a fallback message.
    if len(answer) < 3 and urls:
        answer = "Something went wrong, but I found a link."

    return Answer(answer=answer, urls=urls)
|
|
|
def split_sources(answer: str) -> tuple[str, list[str]]:
    """Split an answer into its text and the course codes it cites.

    The answer is searched for the first marker (in the order listed below)
    that occurs in it and is split on every occurrence of that marker. The
    pieces alternate: even-indexed pieces are kept as answer text and
    odd-indexed pieces are treated as source listings, from which course
    codes are extracted with ``COURSE_PATTERN``.

    Args:
        answer: Raw answer text, possibly containing a "Sources ..." section.

    Returns:
        A ``(text, courses)`` tuple. ``text`` is the answer with the source
        listings and markers removed; ``courses`` holds every code found in
        the listings, in order of appearance (duplicates are kept). When no
        marker occurs, the answer is returned unchanged with an empty list.
    """
    # Order matters: only the first marker present in the answer is used.
    markers = [
        "Sources",
        "Source",
        "References",
        "Reference",
        "sources",
        "source",
        "SOURCE",
    ]
    for marker in markers:
        if marker not in answer:
            continue
        # The marker is present, so the split yields at least two pieces.
        pieces = answer.split(marker)
        text = "".join(pieces[0::2])
        courses = []
        for listing in pieces[1::2]:
            # Accumulate across listings; previously each listing replaced
            # the earlier matches and was then appended to itself.
            courses.extend(re.findall(COURSE_PATTERN, listing))
        return text, courses

    return answer, []
|
|