import os
import re

from dotenv import load_dotenv
from langchain import hub
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser, XMLOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# Suppress grpc and glog logs for Gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

# Map UI model labels to provider model identifiers
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Map model identifiers to the LangChain chat model class that serves them
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}

xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, fulfill all the requirements \
of the prompt and provide citations. If a chunk of the generated text does not use any of the sources (for example, \
introductions or general text), don't put a citation for that chunk and just leave its "citations" section empty. Otherwise, \
list all sources used for that chunk of the text. Remember, don't add inline citations in the text itself under any circumstances.
Add all citations to the separate citations section. Use explicit new lines in the text to show paragraph splits. For each chunk use this example format:
<chunk>
<text>This is a sample text chunk....</text>
<citations>
<citation>1</citation>
<citation>3</citation>
...
</citations>
</chunk>
If the prompt asks for a reference section, add it in a chunk without any citations.
Return a citation for every quote across all articles that justifies the text. Remember to use the following format for your final output:
<cited_text>
<chunk>
<text></text>
<citations>
<citation><source_id></source_id></citation>
...
</citations>
</chunk>
<chunk>
<text></text>
<citations>
<citation><source_id></source_id></citation>
...
</citations>
</chunk>
...
</cited_text>
The entire text should be wrapped in one cited_text. For the References section (if asked for by the prompt), don't add citations.
For the source id, give a valid integer alone without a key.
Here are the sources:{context}"""

xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)


def format_docs_xml(docs: list[Document]) -> str:
    formatted = []
    for i, doc in enumerate(docs):
        doc_str = f"""\
<source id="{i}">
<path>{doc.metadata['source']}</path>
<article_snippet>{doc.page_content}</article_snippet>
</source>"""
        formatted.append(doc_str)
    return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
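
# Illustrative sketch of the string this helper returns (after two leading newlines)
# for two retrieved documents; the file names below are hypothetical:
#
#   <sources><source id="0">
#   <path>papers/example.pdf</path>
#   <article_snippet>First retrieved chunk...</article_snippet>
#   </source>
#   <source id="1">
#   <path>papers/other.pdf</path>
#   <article_snippet>Second retrieved chunk...</article_snippet>
#   </source></sources>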


def get_doc_content(docs, id):
    return docs[id].page_content


def remove_citations(text):
    # Strip both "<id>" markers and bracketed "[id]" citations from the text
    text = re.sub(r'<\d+>', '', text)
    text = re.sub(r'\[\d+\]', '', text)
    return text


def process_cited_text(data, docs):
    # Initialize variables for the combined text and a dictionary for citations
    combined_text = ""
    citations = {}
    # Iterate through the cited_text list
    if 'cited_text' in data:
        for item in data['cited_text']:
            chunk_text = item['chunk'][0]['text']
            combined_text += chunk_text
            citation_ids = []
            # Process the citations for the chunk
            if item['chunk'][1]['citations']:
                for c in item['chunk'][1]['citations']:
                    if c and 'citation' in c:
                        citation = c['citation']
                        if isinstance(citation, dict) and "source_id" in citation:
                            citation = citation['source_id']
                        if isinstance(citation, str):
                            try:
                                citation_ids.append(int(citation))
                            except ValueError:
                                pass  # Handle cases where the string is not a valid integer
            if citation_ids:
                citation_texts = [f"<{cid}>" for cid in citation_ids]
                combined_text += " " + "".join(citation_texts)
            combined_text += "\n\n"
            # Store unique citations in a dictionary
            for citation_id in citation_ids:
                if citation_id not in citations:
                    citations[citation_id] = {'source': docs[citation_id].metadata['source'], 'content': docs[citation_id].page_content}
    return combined_text.strip(), citations
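
# Illustrative sketch of the parsed XML structure this function expects from
# XMLOutputParser (values are hypothetical), and of what it returns:
#
#   data = {"cited_text": [
#       {"chunk": [
#           {"text": "A cited paragraph..."},
#           {"citations": [{"citation": "2"}, {"citation": "0"}]},
#       ]},
#   ]}
#
# process_cited_text(data, docs) would return the combined text
# "A cited paragraph... <2><0>" together with a citations dict keyed by source id,
# e.g. {2: {"source": ..., "content": ...}, 0: {...}}.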


def citations_to_html(citations):
    if citations:
        # Generate the HTML list for the unique citations
        html_content = "<html><body><ul>"
        for citation_id, citation_info in citations.items():
            html_content += (
                f"<li><strong>Source ID:</strong> {citation_id}<br>"
                f"<strong>Path:</strong> {citation_info['source']}<br>"
                f"<strong>Page Content:</strong> {citation_info['content']}</li>"
            )
        html_content += "</ul></body></html>"
        return html_content
    return ""


def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")
    try:
        # Provider API keys are read from the environment (loaded via .env above);
        # the api_key argument is kept for interface compatibility.
        llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        print(f"An error occurred: {e}")
        llm = None
    return llm
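
# Illustrative call (hypothetical settings): load a low-temperature GPT-4o Mini client,
# assuming OPENAI_API_KEY is available in the environment.
#   llm = load_llm("OpenAI GPT 4o Mini", api_key="", temperature=0.2, max_length=1024)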


def create_db_with_langchain(path: list[str], url_content: dict):
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    if path:
        for file in path:
            loader = PyMuPDFLoader(file)
            data = loader.load()
            # split it into chunks
            docs = text_splitter.split_documents(data)
            all_docs.extend(docs)

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            docs = text_splitter.split_documents([doc])
            all_docs.extend(docs)

    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    assert len(all_docs) > 0, "No PDFs or scraped data provided"
    db = Chroma.from_documents(all_docs, embedding_function)
    return db
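
# Illustrative input shapes (hypothetical values): `path` lists local PDF files and
# `url_content` maps an already-scraped URL to its extracted text.
#   db = create_db_with_langchain(
#       path=["papers/example.pdf"],
#       url_content={"https://example.com/article": "Scraped article text..."},
#   )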


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None, None
    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
    rag_prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        if all(isinstance(doc, Document) for doc in docs):
            return "\n\n".join(doc.page_content for doc in docs)
        else:
            raise TypeError("All items in docs must be instances of Document.")

    # Retrieve the documents for the topic once and render them as XML <source> blocks.
    docs = retriever.get_relevant_documents(topic)
    formatted_docs = format_docs_xml(docs)
    # Inject the pre-formatted sources as context, call the model, and parse the
    # XML-structured answer (text chunks plus citations) from its output.
    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})
    text, citations = process_cited_text(result, docs)
    return text, citations


def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None, None
    try:
        output = llm.invoke(prompt).content
        return output, None
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None, None


def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    else:
        return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
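

# Minimal usage sketch (hypothetical prompt and PDF path), assuming the relevant
# provider API key is available in the environment loaded by load_dotenv() above.
if __name__ == "__main__":
    demo_prompt = "Write a short, cited overview of retrieval-augmented generation."
    demo_topic = "retrieval-augmented generation"
    # With a PDF path or scraped URL content, generate() runs the cited RAG chain;
    # with neither, it falls back to a plain LLM call in generate_base().
    text, citations = generate(
        prompt=demo_prompt,
        topic=demo_topic,
        model="OpenAI GPT 4o Mini",
        url_content={},
        path=["example.pdf"],  # hypothetical local PDF
        temperature=0.7,
        max_length=1024,
    )
    print(text)
    print(citations_to_html(citations))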