minko186 committed on
Commit 35710a1 · verified · 1 Parent(s): bc325e5

Delete ai_generate

Files changed (1)
  1. ai_generate +0 -221
ai_generate DELETED
@@ -1,221 +0,0 @@
- import os
- from langchain_community.document_loaders import PyMuPDFLoader
- from langchain_core.documents import Document
- from langchain_community.embeddings.sentence_transformer import (
-     SentenceTransformerEmbeddings,
- )
- from langchain.schema import StrOutputParser
- from langchain_community.vectorstores import Chroma
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain import hub
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- from langchain_groq import ChatGroq
- from langchain_openai import ChatOpenAI
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_anthropic import ChatAnthropic
- from dotenv import load_dotenv
- from langchain_core.output_parsers import XMLOutputParser
- from langchain.prompts import ChatPromptTemplate
-
- load_dotenv()
-
- # suppress grpc and glog logs for gemini
- os.environ["GRPC_VERBOSITY"] = "ERROR"
- os.environ["GLOG_minloglevel"] = "2"
-
- # RAG parameters
- CHUNK_SIZE = 1024
- CHUNK_OVERLAP = CHUNK_SIZE // 8
- K = 10
- FETCH_K = 20
-
- llm_model_translation = {
-     "LLaMA 3": "llama3-70b-8192",
-     "OpenAI GPT 4o Mini": "gpt-4o-mini",
-     "OpenAI GPT 4o": "gpt-4o",
-     "OpenAI GPT 4": "gpt-4-turbo",
-     "Gemini 1.5 Pro": "gemini-1.5-pro",
-     "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
- }
-
- llm_classes = {
-     "llama3-70b-8192": ChatGroq,
-     "gpt-4o-mini": ChatOpenAI,
-     "gpt-4o": ChatOpenAI,
-     "gpt-4-turbo": ChatOpenAI,
-     "gemini-1.5-pro": ChatGoogleGenerativeAI,
-     "claude-3-5-sonnet-20240620": ChatAnthropic,
- }
-
- xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
- fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
- not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the answer.
- At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID.
-
-
- Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
- justifies the answer and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \
- that justify the answer. Use the following format for your final output:
-
- <cited_answer>
- <answer></answer>
- <citations>
- <citation><source_id></source_id><source></source><quote></quote></citation>
- <citation><source_id></source_id><source></source><quote></quote></citation>
- ...
- </citations>
- </cited_answer>
-
- Here are the sources:{context}"""
- xml_prompt = ChatPromptTemplate.from_messages(
-     [("system", xml_system), ("human", "{input}")]
- )
-
- def format_docs_xml(docs: list[Document]) -> str:
-     formatted = []
-     for i, doc in enumerate(docs):
-         doc_str = f"""\
- <source>
- <source>{doc.metadata['source']}</source>
- <title>{doc.metadata['title']}</title>
- <article_snippet>{doc.page_content}</article_snippet>
- </source>"""
-         formatted.append(doc_str)
-     return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
-
-
- def citations_to_html(citations_data):
-     if citations_data:
-         html_output = "<ul>"
-
-         for index, citation in enumerate(citations_data):
-             source_id = citation['citation'][0]['source_id']
-             source = citation['citation'][1]['source']
-             quote = citation['citation'][2]['quote']
-
-             html_output += f"""
- <li>
- [{index}] - "{source}" <br>
- "{quote}"
- </li>
- """
-
-         html_output += "</ul>"
-         return html_output
-     return ""
-
-
- def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
-     model_name = llm_model_translation.get(model)
-     llm_class = llm_classes.get(model_name)
-     if not llm_class:
-         raise ValueError(f"Model {model} not supported.")
-     try:
-         llm = llm_class(model_name=model_name, temperature=temperature, max_tokens=max_length)
-     except Exception as e:
-         print(f"An error occurred: {e}")
-         llm = None
-     return llm
-
-
- def create_db_with_langchain(path: list[str], url_content: dict):
-     all_docs = []
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
-     embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
-     if path:
-         for file in path:
-             loader = PyMuPDFLoader(file)
-             data = loader.load()
-             # split it into chunks
-             docs = text_splitter.split_documents(data)
-             all_docs.extend(docs)
-
-     if url_content:
-         for url, content in url_content.items():
-             doc = Document(page_content=content, metadata={"source": url})
-             # split it into chunks
-             docs = text_splitter.split_documents([doc])
-             all_docs.extend(docs)
-
-     # print docs
-     for idx, doc in enumerate(all_docs):
-         print(f"Doc: {idx} | Length = {len(doc.page_content)}")
-
-     assert len(all_docs) > 0, "No PDFs or scrapped data provided"
-     db = Chroma.from_documents(all_docs, embedding_function)
-     return db
-
-
- def generate_rag(
-     prompt: str,
-     topic: str,
-     model: str,
-     url_content: dict,
-     path: list[str],
-     temperature: float = 1.0,
-     max_length: int = 2048,
-     api_key: str = "",
-     sys_message="",
- ):
-     llm = load_llm(model, api_key, temperature, max_length)
-     if llm is None:
-         print("Failed to load LLM. Aborting operation.")
-         return None
-     db = create_db_with_langchain(path, url_content)
-     retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
-     rag_prompt = hub.pull("rlm/rag-prompt")
-
-     def format_docs(docs):
-         if all(isinstance(doc, Document) for doc in docs):
-             return "\n\n".join(doc.page_content for doc in docs)
-         else:
-             raise TypeError("All items in docs must be instances of Document.")
-
-     docs = retriever.get_relevant_documents(topic)
-     # formatted_docs = format_docs(docs)
-     # rag_chain = (
-     #     {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
-     # )
-     # return rag_chain.invoke(prompt)
-
-     formatted_docs = format_docs_xml(docs)
-     rag_chain = (
-         RunnablePassthrough.assign(context=lambda _: formatted_docs)
-         | xml_prompt
-         | llm
-         | XMLOutputParser()
-     )
-     result = rag_chain.invoke({"input": prompt})
-     print(result)
-     return result['cited_answer'][0]['answer'], result['cited_answer'][1]['citations']
-
-
- def generate_base(
-     prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
- ):
-     llm = load_llm(model, api_key, temperature, max_length)
-     if llm is None:
-         print("Failed to load LLM. Aborting operation.")
-         return None, None
-     try:
-         output = llm.invoke(prompt).content
-         return output, None
-     except Exception as e:
-         print(f"An error occurred while running the model: {e}")
-         return None, None
-
-
- def generate(
-     prompt: str,
-     topic: str,
-     model: str,
-     url_content: dict,
-     path: list[str],
-     temperature: float = 1.0,
-     max_length: int = 2048,
-     api_key: str = "",
-     sys_message="",
- ):
-     return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
-
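
For context, the deleted module's public entry point was generate(), which delegated to the XML-citation RAG pipeline in generate_rag(); provider API keys were expected in the environment (load_dotenv), since load_llm never forwards its api_key argument. A minimal usage sketch based on the signatures above; the import path, PDF filename, and environment variable name are illustrative assumptions, not taken from this repository:

    import os
    from ai_generate import generate  # assumed import path for the deleted module

    # Hypothetical call; the model label must be a key of llm_model_translation.
    answer, citations = generate(
        prompt="Summarize the key findings and cite every source you use.",
        topic="key findings",
        model="OpenAI GPT 4o Mini",
        url_content={},                           # no scraped URL content in this sketch
        path=["example_paper.pdf"],               # illustrative local PDF path
        temperature=0.7,
        max_length=2048,
        api_key=os.getenv("OPENAI_API_KEY", ""),  # accepted but unused; the SDK reads the env var itself
    )
    print(answer)
    print(citations)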