# Utilities to build a RAG system to query information from the
# gwfast package using LangChain
# Thanks to Pablo Villanueva Domingo for sharing his CAMELS template
# https://huggingface.co/spaces/PabloVD/CAMELSDocBot
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain.schema import Document
import requests
import json
import base64
from bs4 import BeautifulSoup
import re
from urllib.parse import urljoin, urlparse
from collections.abc import Iterator
def github_to_raw(url):
"""Convert GitHub URL to raw content URL"""
return url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
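
# Illustrative mapping (hypothetical owner/repo, shown only to document the conversion):
#   github_to_raw("https://github.com/owner/repo/blob/main/src/module.py")
#   -> "https://raw.githubusercontent.com/owner/repo/main/src/module.py"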
def load_github_notebook(url):
"""Load Jupyter notebook from GitHub URL using GitHub API"""
try:
# Convert GitHub blob URL to API URL
if "github.com" in url and "/blob/" in url:
# Extract owner, repo, branch and path from URL
parts = url.replace("https://github.com/", "").split("/")
owner = parts[0]
repo = parts[1]
branch = parts[3] # usually 'main' or 'master'
path = "/".join(parts[4:])
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
else:
raise ValueError("URL must be a GitHub blob URL")
# Fetch notebook content
response = requests.get(api_url)
response.raise_for_status()
content_data = response.json()
if content_data.get('encoding') == 'base64':
notebook_content = base64.b64decode(content_data['content']).decode('utf-8')
else:
notebook_content = content_data['content']
# Parse notebook JSON
notebook = json.loads(notebook_content)
docs = []
cell_count = 0
# Process each cell
for cell in notebook.get('cells', []):
cell_count += 1
cell_type = cell.get('cell_type', 'unknown')
source = cell.get('source', [])
# Join source lines
if isinstance(source, list):
content = ''.join(source)
else:
content = str(source)
if content.strip(): # Only add non-empty cells
metadata = {
'source': url,
'cell_type': cell_type,
'cell_number': cell_count,
'name': f"{url} - Cell {cell_count} ({cell_type})"
}
# Add cell type prefix for better context
formatted_content = f"[{cell_type.upper()} CELL {cell_count}]\n{content}"
docs.append(Document(page_content=formatted_content, metadata=metadata))
return docs
except Exception as e:
print(f"Error loading notebook from {url}: {str(e)}")
return []
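
# Expected output (hypothetical notebook URL): one Document per non-empty cell, with the
# cell type and index prefixed to the text, e.g. "[MARKDOWN CELL 1] ..." or "[CODE CELL 2] ...",
# and the source URL, cell type and cell number stored in the metadata.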
def clean_text(text):
"""Clean text content from a webpage"""
# Remove excessive newlines
text = re.sub(r'\n{3,}', '\n\n', text)
# Remove excessive whitespace
text = re.sub(r'\s{2,}', ' ', text)
return text.strip()
def clean_github_content(html_content):
"""Extract meaningful content from GitHub pages"""
# Ensure we're working with a BeautifulSoup object
if isinstance(html_content, str):
soup = BeautifulSoup(html_content, 'html.parser')
else:
soup = html_content
# Remove navigation, footer, and other boilerplate
for element in soup.find_all(['nav', 'footer', 'header']):
element.decompose()
# For README and code files
readme_content = soup.find('article', class_='markdown-body')
if readme_content:
return clean_text(readme_content.get_text())
# For code files
code_content = soup.find('table', class_='highlight')
if code_content:
return clean_text(code_content.get_text())
# For directory listings
file_list = soup.find('div', role='grid')
if file_list:
return clean_text(file_list.get_text())
# Fallback to main content
main_content = soup.find('main')
if main_content:
return clean_text(main_content.get_text())
# If no specific content found, get text from body
body = soup.find('body')
if body:
return clean_text(body.get_text())
# Final fallback
return clean_text(soup.get_text())
class GitHubLoader(WebBaseLoader):
"""Custom loader for GitHub pages with better content cleaning"""
def clean_text(self, text):
"""Clean text content"""
# Remove excessive newlines and spaces
text = re.sub(r'\n{2,}', '\n', text)
text = re.sub(r'\s{2,}', ' ', text)
# Remove common GitHub boilerplate
text = re.sub(r'Skip to content|Sign in|Search or jump to|Footer navigation|Terms|Privacy|Security|Status|Docs', '', text)
return text.strip()
    def lazy_load(self) -> Iterator[Document]:
"""Override lazy_load instead of _scrape to handle both BeautifulSoup and string returns."""
for url in self.web_paths:
try:
response = requests.get(url)
response.raise_for_status()
# For directory listings (tree URLs), use the API
if '/tree/' in url:
# Parse URL components
parts = url.replace("https://github.com/", "").split("/")
owner = parts[0]
repo = parts[1]
branch = parts[3] # usually 'main' or 'master'
path = "/".join(parts[4:]) if len(parts) > 4 else ""
# Construct API URL
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
api_response = requests.get(api_url)
api_response.raise_for_status()
# Parse directory listing
contents = api_response.json()
if isinstance(contents, list):
# Format directory contents
content = "Directory contents:\n" + "\n".join([f"{item['name']} ({item['type']})" for item in contents])
yield Document(
page_content=self.clean_text(content),
metadata={'source': url, 'type': 'github_directory'}
)
continue
# For regular files, parse HTML
soup = BeautifulSoup(response.text, 'html.parser')
# For README and markdown files
readme_content = soup.find('article', class_='markdown-body')
if readme_content:
yield Document(
page_content=self.clean_text(readme_content.get_text()),
metadata={'source': url, 'type': 'github_markdown'}
)
continue
# For code files
code_content = soup.find('table', class_='highlight')
if code_content:
yield Document(
page_content=self.clean_text(code_content.get_text()),
metadata={'source': url, 'type': 'github_code'}
)
continue
# For other content, get main content
main_content = soup.find('main')
if main_content:
yield Document(
page_content=self.clean_text(main_content.get_text()),
metadata={'source': url, 'type': 'github_other'}
)
continue
# Fallback to whole page content
yield Document(
page_content=self.clean_text(soup.get_text()),
metadata={'source': url, 'type': 'github_fallback'}
)
except Exception as e:
print(f"Error processing {url}: {str(e)}")
continue
def load(self) -> list[Document]:
"""Load method that returns a list of documents."""
return list(self.lazy_load())
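
# Minimal usage sketch (hypothetical repository URL): each loaded page carries a 'type'
# tag such as 'github_markdown', 'github_code' or 'github_directory'.
#   docs = GitHubLoader(["https://github.com/owner/repo/tree/main/src"]).load()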
class ReadTheDocsLoader(WebBaseLoader):
"""Custom loader for ReadTheDocs pages"""
def __init__(self, base_url: str):
"""Initialize with base URL of the documentation."""
super().__init__([])
self.base_url = base_url.rstrip('/')
def clean_text(self, text: str) -> str:
"""Clean text content from ReadTheDocs pages."""
# Remove excessive whitespace and newlines
text = re.sub(r'\s{2,}', ' ', text)
text = re.sub(r'\n{3,}', '\n\n', text)
# Remove common ReadTheDocs boilerplate
text = re.sub(r'View page source|Next|Previous|©.*?\.', '', text)
return text.strip()
def normalize_url(self, base_url: str, href: str) -> str:
"""Normalize relative URLs to absolute URLs."""
# If it's already an absolute URL, return it
if href.startswith(('http://', 'https://')):
return href
# Handle relative URLs
return urljoin(base_url, href)
def get_all_pages(self) -> list[str]:
"""Get all documentation pages starting from the base URL."""
visited = set()
to_visit = {self.base_url}
docs_urls = set()
while to_visit:
url = to_visit.pop()
if url in visited:
continue
visited.add(url)
try:
response = requests.get(url)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
# Add current page if it's a documentation page
if url.startswith(self.base_url):
docs_urls.add(url)
# Find all links
for link in soup.find_all('a'):
href = link.get('href')
if not href:
continue
# Skip anchor links and external links
                    if href.startswith('#') or (href.startswith(('http://', 'https://')) and not href.startswith(self.base_url)):
continue
# Normalize the URL
full_url = self.normalize_url(url, href)
# Only follow links within the documentation domain
if full_url.startswith(self.base_url):
to_visit.add(full_url)
except Exception as e:
print(f"Error fetching {url}: {str(e)}")
return list(docs_urls)
def load(self) -> list[Document]:
"""Load all documentation pages."""
urls = self.get_all_pages()
docs = []
for url in urls:
try:
response = requests.get(url)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
# Get main content
main_content = soup.find('div', {'role': 'main'})
if not main_content:
main_content = soup.find('main')
if not main_content:
continue
# Clean content
content = self.clean_text(main_content.get_text())
if content:
docs.append(Document(
page_content=content,
metadata={'source': url, 'type': 'readthedocs'}
))
except Exception as e:
print(f"Error processing {url}: {str(e)}")
return docs
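
# Minimal usage sketch: crawl the gwfast documentation site and collect one Document per page.
#   rtd_docs = ReadTheDocsLoader("https://gwfast.readthedocs.io/en/latest").load()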
def load_docs():
"""Load all documentation."""
# Get urls
with open("urls.txt", "r") as f:
        urls = [line.strip() for line in f if line.strip()]
docs = []
# Load GitHub content
for url in urls:
if "github.com" in url or "raw.githubusercontent.com" in url:
if "/blob/" in url and url.endswith(".ipynb"):
# Handle Jupyter notebooks
notebook_docs = load_github_notebook(url)
docs.extend(notebook_docs)
elif "raw.githubusercontent.com" in url:
# Handle raw GitHub content directly
try:
response = requests.get(url)
response.raise_for_status()
content = response.text
docs.append(Document(
page_content=content,
metadata={'source': url, 'type': 'github_raw'}
))
except Exception as e:
print(f"Error loading raw content from {url}: {str(e)}")
else:
# Handle other GitHub content
loader = GitHubLoader([url])
docs.extend(loader.load())
# Load ReadTheDocs content
rtd_loader = ReadTheDocsLoader("https://gwfast.readthedocs.io/en/latest")
docs.extend(rtd_loader.load())
return docs
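
# load_docs() expects a plain-text "urls.txt" next to this file with one URL per line
# (GitHub blob/tree pages, raw.githubusercontent.com files, or .ipynb notebooks); the
# gwfast ReadTheDocs pages are crawled automatically in addition to those URLs.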
def extract_reference(url):
"""Extract a reference keyword from the GitHub URL"""
if "blob/main" in url:
return url.split("blob/main/")[-1]
elif "tree/main" in url:
return url.split("tree/main/")[-1] or "root"
elif "blob/master" in url:
return url.split("blob/master/")[-1]
elif "tree/master" in url:
return url.split("tree/master/")[-1] or "root"
elif "refs/heads/master" in url:
return url.split("refs/heads/master/")[-1]
return url
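
# Illustrative mapping (hypothetical repository paths):
#   extract_reference("https://github.com/owner/repo/blob/main/gwfast/signal.py") -> "gwfast/signal.py"
#   extract_reference("https://github.com/owner/repo/tree/main/") -> "root"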
# Join content pages for processing
def format_docs(docs):
formatted_docs = []
for doc in docs:
source = doc.metadata.get('source', 'Unknown source')
reference = f"[{extract_reference(source)}]"
content = doc.page_content
formatted_docs.append(f"{content}\n\nReference: {reference}")
return "\n\n---\n\n".join(formatted_docs)
# Create a RAG chain
def RAG(llm, docs, embeddings):
# Split text
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
# Create vector store
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
# Retrieve and generate using the relevant snippets of the documents
retriever = vectorstore.as_retriever()
# Prompt basis example for RAG systems
prompt = hub.pull("rlm/rag-prompt")
# Adding custom instructions to the prompt
template = prompt.messages[0].prompt.template
template_parts = template.split("\nQuestion: {question}")
combined_template = "You are an assistant for question-answering tasks. "\
+ "Use the following pieces of retrieved context to answer the question. "\
+ "If you don't know the answer, just say that you don't know. "\
+ "Try to keep the answer concise if possible. "\
+ "Write the names of the relevant functions from the retrived code and include code snippets to aid the user's understanding. "\
+ "Include the references used in square brackets at the end of your answer."\
        + "\nQuestion: {question}" + template_parts[1]
prompt.messages[0].prompt.template = combined_template
# Create the chain
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
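
# Minimal end-to-end sketch (illustration only, not part of the Space's app code). The OpenAI
# chat model and embeddings below are assumptions; substitute whatever backends the app uses.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # assumed backend, not a dependency of this module

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    embeddings = OpenAIEmbeddings()
    docs = load_docs()
    rag_chain = RAG(llm, docs, embeddings)
    # The chain takes the question string directly and returns the parsed answer text.
    print(rag_chain.invoke("Which function computes the Fisher matrix in gwfast?"))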