File size: 4,648 Bytes
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a106116
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e75140
8cb8290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
logger = logging.getLogger()

import openai
from pydantic import BaseSettings

from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains import SequentialChain
from langchain.llms import OpenAI
from langchain.chains import LLMCheckerChain
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings

from magic.prompts import PROMPT, EXAMPLE_PROMPT
from magic.self_query_retriever import SelfQueryRetriever

from utils import get_courses

import pinecone


class Settings(BaseSettings):
    """Application configuration loaded from the environment.

    Values are read from process environment variables (and, per the inner
    ``Config``, from a ``.env`` file when present); the literals below are
    fallback defaults used when no environment value is set.
    """
    # OpenAI credentials and chat model name used by the QA chains.
    OPENAI_API_KEY: str = 'OPENAI_API_KEY'
    OPENAI_CHAT_MODEL: str = 'gpt-3.5-turbo'
    # Pinecone connection details for the existing vector index.
    PINECONE_API_KEY: str = 'PINECONE_API_KEY'
    PINECONE_INDEX_NAME: str = 'kth-qa'
    PINECONE_ENV: str = 'us-west1-gcp-free'
    class Config:
        # Tell pydantic to also read settings from a local .env file.
        env_file = '.env'

def set_openai_key(key):
    """Install *key* as the global API key on the ``openai`` client module.

    All subsequent OpenAI calls made through this process (directly or via
    langchain) authenticate with this key.
    """
    openai.api_key = key

class State:
    """Container for the application's long-lived objects.

    Holds the resolved :class:`Settings`, the Pinecone vector store, the
    question-answering chain, and the list of known course codes.

    NOTE: constructing a ``State`` has network side effects — it sets the
    global OpenAI key and connects to an existing Pinecone index.
    """
    settings: Settings
    store: Pinecone
    chain: RetrievalQAWithSourcesChain
    courses: list

    def __init__(self):
        self.settings = Settings()

        self.courses = get_courses()

        # OPENAI: install the API key globally for the openai client.
        set_openai_key(self.settings.OPENAI_API_KEY)

        # PINECONE VECTORSTORE: connect to the pre-built index; "text" is the
        # metadata field holding the raw document text.
        embeddings = OpenAIEmbeddings()
        pinecone.init(api_key=self.settings.PINECONE_API_KEY, environment=self.settings.PINECONE_ENV)
        self.store: Pinecone = Pinecone.from_existing_index(self.settings.PINECONE_INDEX_NAME, embeddings, "text")
        logger.info("Pinecone store initialized")

        # CHAINS
        doc_chain = self._load_doc_chain()
        qa_chain = self._load_qa_chain(doc_chain, self_query=True)

        # JUST QA
        self.chain = qa_chain

        # SEQ CHAIN with QA and CHECKER (alternative pipeline, currently disabled):
        # checker_chain = self._load_checker_chain()
        # self.chain = self._load_seq_chain([qa_chain, checker_chain])

    def _load_seq_chain(self, chains):
        """Wire *chains* into a sequential pipeline: question in, answer out."""
        sequential_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["answer"],
            verbose=True)
        return sequential_chain

    def _load_checker_chain(self):
        """Build a checker chain that re-verifies a produced answer."""
        llm = OpenAI(temperature=0)
        checker_chain = LLMCheckerChain(llm=llm, verbose=True, input_key="answer", output_key="result")
        return checker_chain

    def _load_doc_chain(self):
        """Build the 'stuff' combine-documents chain used for QA with sources."""
        doc_chain = load_qa_with_sources_chain(
            ChatOpenAI(temperature=0, max_tokens=256, model=self.settings.OPENAI_CHAT_MODEL, request_timeout=120),
            chain_type="stuff",
            document_variable_name="context",
            prompt=PROMPT,
            document_prompt=EXAMPLE_PROMPT
        )
        return doc_chain

    def _load_qa_chain(self, doc_chain, self_query=False):
        """Load QA chain with retriever.

        If *self_query* is True, the retriever will be a SelfQueryRetriever,
        which extracts a metadata filter (the course code) from the question
        and adds it to the vectorstore query; otherwise a plain vector-store
        retriever is used.
        """
        if self_query:
            metadata_field_info=[
                AttributeInfo(
                    name="course",
                    description="A course code for a course",
                    type="string"
                )]
            document_content_description = "Brief description of a course"
            llm = OpenAI(temperature=0, model_name='text-davinci-002')
            retriever = SelfQueryRetriever.from_llm(llm, self.store, document_content_description,
                                                    metadata_field_info, verbose=True)
            qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
                                                    retriever=retriever,
                                                    return_source_documents=False)
        else:
            qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
                        retriever=self.store.as_retriever(),
                        return_source_documents=False)
        return qa_chain

    def course_exists(self, course: str):
        """Return True if *course* (case-insensitive) is a known course code."""
        course = course.upper()
        exists = course in self.courses
        # Lazy %-style args so formatting only happens if INFO is enabled.
        logger.info('Course %s %s', course, 'exists' if exists else 'does not exist')
        return exists
    
if __name__ == '__main__':
    # Smoke test: build the full application state and echo the resolved settings.
    app_state = State()
    print(app_state.settings)