Todd Deshane committed on
Commit
a5f4443
·
1 Parent(s): 4097737

add in youtube rag

Browse files
Files changed (1) hide show
  1. tools.py +76 -0
tools.py CHANGED
@@ -55,6 +55,27 @@ def _generate_image(prompt: str):
55
  cl.user_session.set("generated_image", name)
56
  return name
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def generate_image(prompt: str):
60
  image_name = _generate_image(prompt)
@@ -71,3 +92,58 @@ generate_image_tool = Tool.from_function(
71
  description=f"Useful to create an image from a text prompt. Input should be a single string strictly in the following JSON format: {generate_image_format}",
72
  return_direct=True,
73
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  cl.user_session.set("generated_image", name)
56
  return name
57
 
58
def _youtube_rag(prompt: str):
    """Answer ``prompt`` using the YouTube-transcript vector store.

    Fetches the documents relevant to ``prompt`` from the retriever returned
    by ``initialize_chroma_db`` and feeds them to a "stuff" question-answering
    chain backed by an OpenAI chat model.

    Args:
        prompt: The question to answer.

    Returns:
        Whatever the question-answering chain's ``run`` call returns.

    Raises:
        KeyError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    openai.api_key = os.environ["OPENAI_API_KEY"]

    # Transcripts are not processed in this function, so there are no texts
    # to hand over if the store has to be created.
    flattened_texts = []

    if os.path.exists(persist_directory):
        if debug:
            print("Database exists, skipping transcript processing...")
    else:
        # NOTE(review): the store is then built from an empty list of texts;
        # confirm that Chroma accepts that, or populate the store beforehand.
        print("Database does not exist")

    if debug:
        print("Initializing database...")
    docsearch = initialize_chroma_db(flattened_texts)

    docs = docsearch.get_relevant_documents(prompt)
    chat_model = ChatOpenAI(model_name="gpt-4-1106-preview")
    chain = load_qa_chain(llm=chat_model, chain_type="stuff")
    # The question is the caller's prompt. The previous code passed the
    # undefined name `query`, which raised NameError on every call.
    return chain.run(input_documents=docs, question=prompt)
79
 
80
  def generate_image(prompt: str):
81
  image_name = _generate_image(prompt)
 
92
  description=f"Useful to create an image from a text prompt. Input should be a single string strictly in the following JSON format: {generate_image_format}",
93
  return_direct=True,
94
  )
95
+
96
+
97
+ import os
98
+ import openai
99
+ from langchain.chat_models import ChatOpenAI
100
+ from langchain.embeddings.openai import OpenAIEmbeddings
101
+ from langchain.vectorstores import Chroma
102
+ from langchain.chains.question_answering import load_qa_chain
103
+ from langchain.text_splitter import CharacterTextSplitter
104
+ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
105
+
106
# Set to True to print progress messages while the vector store is set up.
debug = False

# Directory in which the Chroma vector store is persisted.
persist_directory = "db"

# Embedding function handed to Chroma when loading or creating the store.
embedding_function = SentenceTransformerEmbeddings(
    model_name="all-MiniLM-L6-v2",
)
110
+
111
+ # Function to initialize or load the Chroma database
112
def initialize_chroma_db(texts):
    """Load the Chroma store at ``persist_directory``, or create it from ``texts``.

    Args:
        texts: Texts used to build the store when ``persist_directory`` does
            not exist yet. Ignored when an existing store is loaded.

    Returns:
        The store wrapped as a retriever (``db.as_retriever()``).
    """
    if os.path.exists(persist_directory):
        if debug:
            print("Loading existing database...")
        # Load from the same path that was just checked. The previous code
        # hard-coded "./db" here, which would diverge from the check if
        # `persist_directory` were ever changed.
        db = Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding_function,
        )
    else:
        if debug:
            print("Creating new database...")
        db = Chroma.from_texts(
            texts,
            embedding_function,
            persist_directory=persist_directory,
        )
    return db.as_retriever()
125
+
126
+
127
+
128
def youtube_rag(prompt: str):
    """Tool entry point: answer ``prompt`` from the YouTube transcript store.

    Args:
        prompt: The input string passed to the tool.

    Returns:
        The answer from ``_youtube_rag`` with a leading space and a trailing
        period added.
    """
    answer = _youtube_rag(prompt)
    return f" {answer}."


# This is the YouTube RAG tool, which lets the agent query the YouTube
# transcript vector database. The `description` field is of utmost importance,
# as it is what the LLM "brain" uses to determine which tool to use for a
# given input.
#
# The tool is built after `youtube_rag` is defined, because
# `Tool.from_function` needs the function object at import time. It is bound
# to its own name so that it does not overwrite `generate_image_tool`.
youtube_rag_format = '{{"prompt": "prompt"}}'
youtube_rag_tool = Tool.from_function(
    func=youtube_rag,
    name="Youtube_Rag",
    description=f"Useful to query the vector database containing youtube transcripts about Aaron Lebauer. Input should be a single string strictly in the following JSON format: {youtube_rag_format}",
    return_direct=True,
)
143
+
144
+
145
+
146
+
147
+
148
+
149
+