minko186 committed
Commit e349d37 · verified · 1 Parent(s): f610ce3

Create app.py

Files changed (1): app.py (+297, -0)
app.py ADDED
@@ -0,0 +1,297 @@
import os

import gradio as gr
from dotenv import load_dotenv
from langchain import hub
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser, XMLOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# suppress grpc and glog logs for gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20
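# With the MMR retriever configured below, FETCH_K candidate chunks are fetched
# first and re-ranked for diversity; only the top K reach the LLM as context.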

llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}
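# Each class reads its API key from the environment (loaded via .env above):
# GROQ_API_KEY, OPENAI_API_KEY, GOOGLE_API_KEY, and ANTHROPIC_API_KEY.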

xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the answer. \
At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID.

Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
justifies the answer and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \
that justifies the answer. Use the following format for your final output:

<cited_answer>
    <answer></answer>
    <citations>
        <citation><source_id></source_id><source></source><quote></quote></citation>
        <citation><source_id></source_id><source></source><quote></quote></citation>
        ...
    </citations>
</cited_answer>

Here are the sources:{context}"""

xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)
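# {context} is filled with the XML-serialized sources and {input} with the
# user's prompt when the chain in generate_rag invokes this template.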


def format_docs_xml(docs: list[Document]) -> str:
    formatted = []
    for i, doc in enumerate(docs):
        # The enumerate index doubles as the <source_id> the system prompt refers to;
        # .get() avoids a KeyError for URL docs, which carry no "title" metadata
        doc_str = f"""\
    <source>
        <source_id>{i}</source_id>
        <source>{doc.metadata.get('source', '')}</source>
        <title>{doc.metadata.get('title', '')}</title>
        <article_snippet>{doc.page_content}</article_snippet>
    </source>"""
        formatted.append(doc_str)
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
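
# Illustrative output for a single retrieved chunk (values are placeholders):
# <sources>
#     <source>
#         <source_id>0</source_id>
#         <source>paper.pdf</source>
#         <title>Example Title</title>
#         <article_snippet>...</article_snippet>
#     </source>
# </sources>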


def citations_to_html(citations_data):
    if citations_data:
        html_output = "<ul>"
        for index, citation in enumerate(citations_data):
            # Each parsed citation is a list of single-key dicts from XMLOutputParser
            source_id = citation["citation"][0]["source_id"]
            source = citation["citation"][1]["source"]
            quote = citation["citation"][2]["quote"]
            html_output += f"""
            <li>
                [{index}] - "{source}" <br>
                "{quote}"
            </li>
            """
        html_output += "</ul>"
        return html_output
    return ""
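
# Expects the parsed <citations> block from XMLOutputParser, e.g.
# [{"citation": [{"source_id": "0"}, {"source": "paper.pdf"}, {"quote": "..."}]}]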


def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # "model" is the parameter name shared by these chat classes
        # (ChatGoogleGenerativeAI in particular does not accept model_name);
        # API keys are picked up from the environment via load_dotenv above
        llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
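
# Example usage (a sketch; assumes OPENAI_API_KEY is set in the .env file):
#   llm = load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.7)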


def create_db_with_langchain(path: list[str], url_content: dict):
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)

    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
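
# Note: this builds an in-memory Chroma collection; pass persist_directory
# to Chroma.from_documents if the index should survive restarts.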


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None
    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
    rag_prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        if all(isinstance(doc, Document) for doc in docs):
            return "\n\n".join(doc.page_content for doc in docs)
        else:
            raise TypeError("All items in docs must be instances of Document.")

    docs = retriever.get_relevant_documents(topic)
    # formatted_docs = format_docs(docs)
    # rag_chain = (
    #     {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
    # )
    # return rag_chain.invoke(prompt)

    formatted_docs = format_docs_xml(docs)
    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})
    return result
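
# XMLOutputParser returns nested dicts/lists mirroring the XML, roughly:
# {"cited_answer": [{"answer": "..."}, {"citations": [{"citation": [...]}, ...]}]}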


def process_input(topic, length, tone, format_, pdfs):
    # Construct the prompt
    prompt = f"Write a {format_} about {topic} in about {length} words and a {tone} tone."

    # Generate the text and citations using RAG
    rag_output = generate_rag(
        prompt=prompt,
        topic=topic,
        model="OpenAI GPT 4o",  # Replace with your model name or path
        url_content=None,
        path=pdfs,
        temperature=1.0,
        max_length=2048,
        api_key="",  # Add your API key if necessary
        sys_message="",
    )

    # Walk the nested XMLOutputParser result (see the shape sketched above)
    # to pull out the answer and citations
    generated_text, citations = "", ""
    for part in (rag_output or {}).get("cited_answer", []):
        if "answer" in part:
            generated_text = part["answer"]
        elif "citations" in part:
            citations = str(part["citations"])

    return generated_text, citations


def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)


def create_app():
    with gr.Blocks() as app:
        with gr.Row():
            topic_input = gr.Textbox(
                label="Topic",
                placeholder="Enter the main topic of your article",
                elem_classes="input-highlight-pink",
            )
            length_input = gr.Slider(
                minimum=50,
                maximum=5000,
                step=50,
                value=300,
                label="Article Length",
                elem_classes="input-highlight-pink",
            )
            tone_input = gr.Dropdown(
                choices=[
                    "Formal",
                    "Informal",
                    "Technical",
                    "Conversational",
                    "Journalistic",
                    "Academic",
                    "Creative",
                ],
                value="Formal",
                label="Writing Style",
                elem_classes="input-highlight-yellow",
            )
            format_input = gr.Dropdown(
                choices=[
                    "Article",
                    "Essay",
                    "Blog post",
                    "Report",
                    "Research paper",
                    "News article",
                    "White paper",
                    "Email",
                    "LinkedIn post",
                    "X (Twitter) post",
                    "Instagram Video Content",
                    "TikTok Video Content",
                    "Facebook post",
                ],
                value="Article",
                label="Format",
                elem_classes="input-highlight-turquoise",
            )

        # gr.File takes file_count="multiple" (not multiple=True) for multi-upload
        pdf_input = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
        generate_button = gr.Button("Generate")

        generated_text_output = gr.Textbox(label="Generated Text", lines=10)
        citations_output = gr.Textbox(label="Citations", lines=10)

        generate_button.click(
            fn=process_input,
            inputs=[topic_input, length_input, tone_input, format_input, pdf_input],
            outputs=[generated_text_output, citations_output],
        )

    return app


# Run the app
app = create_app()
app.launch()