File size: 2,686 Bytes
8cb8290
 
 
 
a106116
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e1de02
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import asyncio
import re
import logging

from schema import Answer, Question
logger = logging.getLogger()
import re
from ingest import KURS_URL, DEFAULT_LANGUAGE
from langchain.callbacks import get_openai_callback

from config import State

# Regex used with re.findall to pull course codes out of free text:
# 2-3 word characters, then 3-4 digits, then an optional trailing word
# character. The pattern is unanchored, and \w also matches digits and "_".
COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?" # e.g. DD1315

def blocking_chain(chain, request):
    """Call ``chain`` synchronously with ``request`` and return its outputs.

    The chain is always invoked with ``return_only_outputs=True``. Being a
    plain (non-async) function, this can be handed to a worker thread.
    """
    outputs = chain(request, return_only_outputs=True)
    return outputs

async def question_handler(question: Question, state: State) -> Answer:
    """Answer a course question and attach links to the courses it cites.

    Runs ``state.chain`` in a worker thread via ``asyncio.to_thread``, then
    extracts course codes from the chain's ``sources`` output or, when that
    is empty, from a sources section embedded in the answer text. Only codes
    accepted by ``state.course_exists`` are kept; each is formatted into a
    URL with ``KURS_URL``.

    Args:
        question: Incoming request; only its ``question`` attribute is read.
        state: Provides the ``chain`` callable and the ``course_exists``
            lookup.

    Returns:
        An ``Answer`` holding the answer text and a (possibly empty) list of
        course URLs.

    Raises:
        KeyError: If the chain result has no ``"answer"`` key.
    """
    query = question.question
    logger.info(f"Q: {query}")

    with get_openai_callback() as cb:
        result = await asyncio.to_thread(blocking_chain, state.chain, {"question": query})
        cost = cb.total_cost
    logger.debug(f"result: {result}")

    answer = result['answer']
    logger.info(f"A: {answer}")

    if answer.startswith("I cannot help"):
        # Refusals carry no links. Built with the same ``urls`` field as the
        # normal return below so both paths produce the same Answer shape.
        return Answer(answer="I'm sorry, " + answer, urls=[])

    raw_sources = result.get('sources')
    logger.info(f"Sources: {raw_sources}")
    if raw_sources:
        codes = re.findall(COURSE_PATTERN, raw_sources)
    elif "none of the sources" not in answer.lower():
        # The chain sometimes leaves the sources inside the answer text.
        answer, codes = split_sources(answer)
    else:
        # No sources reported and the answer itself says none applied;
        # fall back to an empty list instead of iterating over None.
        codes = []

    # Keep only known courses; dedupe while preserving first-seen order so
    # the resulting URL list is stable between runs.
    courses = list(dict.fromkeys(
        code.upper() for code in codes if state.course_exists(code)
    ))
    logger.info(f"unique courses: {courses}")

    urls = [KURS_URL.format(course_code=course, language=DEFAULT_LANGUAGE) for course in courses]
    logger.info(f"urls: {urls}")

    # Drop a dangling "(" left behind when a "(Sources: ...)" tail is cut off.
    answer = answer.strip().removesuffix("(").strip()

    if len(answer) < 3 and urls:
        answer = "Something went wrong, but I found a link."

    logger.info(f"Cost of query: ${cost:.2g}")

    return Answer(answer=answer, urls=urls)

def split_sources(answer: str) -> tuple[str, list[str]]:
    """Split an answer into its prose and the course codes it cites.

    The first marker from a fixed list ("Sources", "Source", "References",
    ...) that occurs in ``answer`` is used as the separator. The pieces
    around it are treated as alternating answer / sources segments:
    even-indexed pieces are concatenated into the answer text, and course
    codes are collected from every odd-indexed piece. With a single
    occurrence this is simply "text before the marker" and "codes found
    after the marker".

    Args:
        answer: Raw answer text, possibly containing a sources section.

    Returns:
        A ``(text, courses)`` tuple. ``courses`` lists the ``COURSE_PATTERN``
        matches in order of appearance (duplicates are kept). When no marker
        occurs in ``answer`` the result is ``(answer, [])``.
    """
    # Tried in this order; matching is a case-sensitive substring test.
    markers = (
        "Sources",
        "Source",
        "References",
        "Reference",
        "sources",
        "source",
        "SOURCE",
    )
    for marker in markers:
        if marker not in answer:
            continue
        pieces = answer.split(marker)
        text = "".join(pieces[0::2])
        courses = []
        for piece in pieces[1::2]:
            # Accumulate across all source segments rather than keeping
            # only the matches from the last one.
            courses.extend(re.findall(COURSE_PATTERN, piece))
        return text, courses

    return answer, []