bstraehle committed on
Commit d1d84e5 · 1 Parent(s): eceefb4

Update app.py

Files changed (1)
  1. app.py +0 -97
app.py CHANGED
@@ -2,119 +2,22 @@ import gradio as gr
 import openai, os, time
 
 from dotenv import load_dotenv, find_dotenv
-from langchain.chains import LLMChain, RetrievalQA
-from langchain.chat_models import ChatOpenAI
-from langchain.document_loaders import PyPDFLoader, WebBaseLoader
-from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
-from langchain.document_loaders.generic import GenericLoader
-from langchain.document_loaders.parsers import OpenAIWhisperParser
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.prompts import PromptTemplate
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import Chroma
-from langchain.vectorstores import MongoDBAtlasVectorSearch
-from pymongo import MongoClient
 
 from rag import llm_chain, rag_chain
 from trace import wandb_trace
 
 _ = load_dotenv(find_dotenv())
 
-PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
-WEB_URL = "https://openai.com/research/gpt-4"
-YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
-YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
-YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
-
-YOUTUBE_DIR = "/data/youtube"
-CHROMA_DIR = "/data/chroma"
-
-MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
-MONGODB_DB_NAME = "langchain_db"
-MONGODB_COLLECTION_NAME = "gpt-4"
-MONGODB_INDEX_NAME = "default"
-
-LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
-RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
-
 RAG_OFF = "Off"
 RAG_CHROMA = "Chroma"
 RAG_MONGODB = "MongoDB"
 
-client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
-collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
-
 config = {
     "chunk_overlap": 150,
     "chunk_size": 1500,
     "k": 3,
-    "model_name": "gpt-4-0613",
-    "temperature": 0,
 }
 
-def document_loading_splitting():
-    # Document loading
-    docs = []
-
-    # Load PDF
-    loader = PyPDFLoader(PDF_URL)
-    docs.extend(loader.load())
-
-    # Load Web
-    loader = WebBaseLoader(WEB_URL)
-    docs.extend(loader.load())
-
-    # Load YouTube
-    loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
-                                               YOUTUBE_URL_2,
-                                               YOUTUBE_URL_3], YOUTUBE_DIR),
-                           OpenAIWhisperParser())
-    docs.extend(loader.load())
-
-    # Document splitting
-    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
-                                                   chunk_size = config["chunk_size"])
-    split_documents = text_splitter.split_documents(docs)
-
-    return split_documents
-
-def document_storage_chroma(documents):
-    Chroma.from_documents(documents = documents,
-                          embedding = OpenAIEmbeddings(disallowed_special = ()),
-                          persist_directory = CHROMA_DIR)
-
-def document_storage_mongodb(documents):
-    MongoDBAtlasVectorSearch.from_documents(documents = documents,
-                                            embedding = OpenAIEmbeddings(disallowed_special = ()),
-                                            collection = collection,
-                                            index_name = MONGODB_INDEX_NAME)
-
-def document_retrieval_chroma(llm, prompt):
-    return Chroma(embedding_function = OpenAIEmbeddings(),
-                  persist_directory = CHROMA_DIR)
-
-def document_retrieval_mongodb(llm, prompt):
-    return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
-                                                           MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
-                                                           OpenAIEmbeddings(disallowed_special = ()),
-                                                           index_name = MONGODB_INDEX_NAME)
-
-def llm_chain(llm, prompt):
-    llm_chain = LLMChain(llm = llm,
-                         prompt = LLM_CHAIN_PROMPT,
-                         verbose = False)
-    completion = llm_chain.generate([{"question": prompt}])
-    return completion, llm_chain
-
-def rag_chain(llm, prompt, db):
-    rag_chain = RetrievalQA.from_chain_type(llm,
-                                            chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
-                                            retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
-                                            return_source_documents = True,
-                                            verbose = False)
-    completion = rag_chain({"query": prompt})
-    return completion, rag_chain
-
 def invoke(openai_api_key, rag_option, prompt):
     if (openai_api_key == ""):
         raise gr.Error("OpenAI API Key is required.")
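
Note: the llm_chain and rag_chain helpers removed above are now imported from the new rag module, and wandb_trace from trace. Below is a minimal sketch of what rag.py presumably exposes, reconstructed from the removed code; the environment-variable prompt templates are carried over from app.py, and the hard-coded k = 3 stands in for the config["k"] value that used to live in app.py. It is a sketch under those assumptions, not the actual contents of rag.py.

# rag.py -- sketch only, assuming the helpers removed from app.py moved here largely unchanged.
import os

from langchain.chains import LLMChain, RetrievalQA
from langchain.prompts import PromptTemplate

# Prompt templates read from the environment, as in the removed app.py code.
LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])

def llm_chain(llm, prompt):
    # Plain LLM completion, no retrieval.
    chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT, verbose = False)
    completion = chain.generate([{"question": prompt}])
    return completion, chain

def rag_chain(llm, prompt, db):
    # Retrieval-augmented completion: pull the top-k chunks from the vector store, then answer.
    # k = 3 is assumed here; app.py previously supplied it via config["k"].
    chain = RetrievalQA.from_chain_type(llm,
                                        chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
                                        retriever = db.as_retriever(search_kwargs = {"k": 3}),
                                        return_source_documents = True,
                                        verbose = False)
    completion = chain({"query": prompt})
    return completion, chain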