# NOTE: removed non-Python artifact lines (file-size header, blame hashes, line-number gutter)
import logging
logger = logging.getLogger()
import openai
from pydantic import BaseSettings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains import SequentialChain
from langchain.llms import OpenAI
from langchain.chains import LLMCheckerChain
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from magic.prompts import PROMPT, EXAMPLE_PROMPT
from magic.self_query_retriever import SelfQueryRetriever
from utils import get_courses
import pinecone
class Settings(BaseSettings):
    """Application configuration loaded via pydantic BaseSettings.

    Values are read from the environment (or a `.env` file, see Config);
    the string defaults here are placeholders, not real credentials.
    """
    # OpenAI credentials and chat model selection
    OPENAI_API_KEY: str = 'OPENAI_API_KEY'
    OPENAI_CHAT_MODEL: str = 'gpt-3.5-turbo'
    # Pinecone vector-store connection details
    PINECONE_API_KEY: str = 'PINECONE_API_KEY'
    PINECONE_INDEX_NAME: str = 'kth-qa'
    PINECONE_ENV: str = 'us-west1-gcp-free'
    class Config:
        # pydantic reads environment overrides from this file when present
        env_file = '.env'
def set_openai_key(key):
    """Install `key` as the module-global API key on the openai client."""
    openai.api_key = key
class State:
    """Singleton-style application state for the QA service.

    Holds the resolved settings, the Pinecone vector store, the QA chain,
    and the list of known course codes. Constructing it performs network
    I/O (OpenAI key setup, Pinecone index connection, chain wiring).
    """
    settings: Settings
    store: Pinecone
    chain: RetrievalQAWithSourcesChain
    courses: list

    def __init__(self):
        self.settings = Settings()
        self.courses = get_courses()
        # OPENAI
        set_openai_key(self.settings.OPENAI_API_KEY)
        # PINECONE VECTORSTORE
        embeddings = OpenAIEmbeddings()
        pinecone.init(api_key=self.settings.PINECONE_API_KEY, environment=self.settings.PINECONE_ENV)
        self.store: Pinecone = Pinecone.from_existing_index(self.settings.PINECONE_INDEX_NAME, embeddings, "text")
        # Plain string: the original used an f-string with no placeholders.
        logger.info("Pinecone store initialized")
        # CHAINS
        doc_chain = self._load_doc_chain()
        qa_chain = self._load_qa_chain(doc_chain, self_query=True)
        # JUST QA
        self.chain = qa_chain
        # SEQ CHAIN with QA and CHECKER (kept for reference, currently disabled)
        # checker_chain = self._load_checker_chain()
        # self.chain = self._load_seq_chain([qa_chain, checker_chain])

    def _load_seq_chain(self, chains):
        """Wrap `chains` in a SequentialChain mapping 'question' -> 'answer'."""
        sequential_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["answer"],
            verbose=True)
        return sequential_chain

    def _load_checker_chain(self):
        """Build an LLMCheckerChain that fact-checks 'answer' into 'result'."""
        llm = OpenAI(temperature=0)
        checker_chain = LLMCheckerChain(llm=llm, verbose=True, input_key="answer", output_key="result")
        return checker_chain

    def _load_doc_chain(self):
        """Build the 'stuff' QA-with-sources chain over retrieved documents."""
        doc_chain = load_qa_with_sources_chain(
            ChatOpenAI(temperature=0, max_tokens=256, model=self.settings.OPENAI_CHAT_MODEL, request_timeout=120),
            chain_type="stuff",
            document_variable_name="context",
            prompt=PROMPT,
            document_prompt=EXAMPLE_PROMPT
        )
        return doc_chain

    def _load_qa_chain(self, doc_chain, self_query=False):
        """Load QA chain with retriever.

        If self_query is True, the retriever will be a SelfQueryRetriever,
        which will extract a metadata filter from the question and add it to
        the vectorstore query; otherwise a plain vectorstore retriever is used.
        """
        if self_query:
            metadata_field_info = [
                AttributeInfo(
                    name="course",
                    description="A course code for a course",
                    type="string"
                )]
            document_content_description = "Brief description of a course"
            llm = OpenAI(temperature=0, model_name='text-davinci-002')
            retriever = SelfQueryRetriever.from_llm(llm, self.store, document_content_description,
                                                    metadata_field_info, verbose=True)
        else:
            retriever = self.store.as_retriever()
        # Single construction point — the original duplicated this call in both branches.
        qa_chain = RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain,
                                               retriever=retriever,
                                               return_source_documents=False)
        return qa_chain

    def course_exists(self, course: str) -> bool:
        """Return True if `course` (case-insensitive) is a known course code."""
        course = course.upper()
        exists = course in self.courses
        # Lazy %-args defer formatting until the log record is actually emitted.
        if exists:
            logger.info('Course %s exists', course)
        else:
            logger.info('Course %s does not exist', course)
        return exists
if __name__ == '__main__':
    # Smoke test: build the full application state (performs network I/O
    # against OpenAI and Pinecone) and print the resolved settings.
    state = State()
    print(state.settings)