import os
import json
import shutil
import hashlib
import tarfile
import xml.etree.ElementTree as ET
from urllib import request

import requests
import pinecone
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from PyPDF2 import PdfReader

from s3_utils import S3Handler
from config import get_settings

class PubMedDownloader:
    """Downloads PMC open-access article packages, extracts metadata and full text, and indexes embeddings in Pinecone."""

    def __init__(self, s3_handler, pubmed_base_url, pinecone_index, embedding_model,
                 from_date="2024-01-01", until_date="2024-11-01", limit=3):
        self.s3_handler = s3_handler
        self.settings = get_settings()
        self.pubmed_base_url = pubmed_base_url
        self.from_date = from_date
        self.until_date = until_date
        self.limit = limit
        self.local_download_dir = "downloaded_pdfs"
        os.makedirs(self.local_download_dir, exist_ok=True)
        self.pinecone_index = pinecone_index  # Pinecone index instance
        self.embedding_model = embedding_model  # Embedding model instance

    def split_and_embed(self, documents, metadata_entry):
        """Split documents into chunks, embed them, and upsert the vectors to Pinecone in batches."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.CHUNK_SIZE,
            chunk_overlap=self.settings.CHUNK_OVERLAP
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Total chunks created: {len(chunks)}")

        batch_size = 50
        pmc_id = metadata_entry["pmc_id"]
        for batch_start in range(0, len(chunks), batch_size):
            batch = chunks[batch_start: batch_start + batch_size]
            print(f"Batch size: {len(batch)}")
            try:
                # Build unique, article-scoped ids for the batch
                ids = [f"{pmc_id}_chunk_{batch_start}_{j}" for j in range(len(batch))]
                print(f"Number of ids: {len(ids)}, sample: {ids[0]}")

                # Embed the chunk texts
                texts = [doc.page_content for doc in batch]
                embeddings = self.embedding_model.embed_documents(texts)

                # Attach each chunk's text to a copy of the article-level metadata
                metadata = []
                for doc in batch:
                    chunk_metadata = metadata_entry.copy()
                    chunk_metadata["text"] = doc.page_content
                    metadata.append(chunk_metadata)

                # Upsert the batch to Pinecone
                to_upsert = list(zip(ids, embeddings, metadata))
                self.pinecone_index.upsert(vectors=to_upsert)
                print(f"Successfully upserted {len(to_upsert)} chunks to Pinecone.")
            except Exception as e:
                print(f"Error processing batch starting at chunk {batch_start}: {e}")

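    # Each vector upserted above is an (id, values, metadata) tuple, e.g. (values abridged, ids hypothetical):
    #   ("PMC1234567_chunk_0_0", [0.0123, -0.0456, ...], {"pmc_id": "PMC1234567", ..., "text": "chunk text"})
    # Prefixing ids with the PMC id keeps chunks from different articles from colliding in the index.
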
    def fetch_records(self, resumption_token=None):
        """
        Fetch article records from the configured PubMed/PMC endpoint, optionally resuming a paginated listing.

        Args:
            resumption_token (str, optional): Token to resume fetching records. Defaults to None.

        Returns:
            ElementTree.Element: Parsed XML root of the API response.
        """
        url = self.pubmed_base_url
        params = {"format": "tgz"}

        # Restrict results to the configured date range, if provided
        if self.from_date and self.until_date:
            params["from"] = self.from_date
            params["until"] = self.until_date

        # Resume pagination if a token is available
        if resumption_token:
            params["resumptionToken"] = resumption_token
            print(f"Using resumption token: {resumption_token}")

        response = requests.get(url, params=params)
        response.raise_for_status()  # Raise an error for bad HTTP responses

        return ET.fromstring(response.content)

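    # Illustrative (abridged) shape of the XML returned above; element and attribute names are
    # inferred from how the response is consumed in process_and_save(), and all values are hypothetical:
    #
    #   <records>
    #     <record id="PMC1234567">
    #       <link format="tgz" href="ftp://ftp.ncbi.nlm.nih.gov/.../PMC1234567.tar.gz" />
    #       <link format="pdf" href="ftp://ftp.ncbi.nlm.nih.gov/.../main.pdf" />
    #     </record>
    #     ...
    #   </records>
    #   <resumption>
    #     <link token="..." href="..." />
    #   </resumption>
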
    def save_metadata_to_s3(self, metadata, bucket, key):
        """Upload a metadata string to S3."""
        print(f"Saving metadata to S3: s3://{bucket}/{key}")
        self.s3_handler.upload_string_to_s3(metadata, bucket, key)

    def save_pdf_to_s3(self, local_filename, bucket, s3_key):
        """Upload a PDF to S3, then delete the local copy."""
        print(f"Uploading PDF to S3: s3://{bucket}/{s3_key}")
        self.s3_handler.upload_file_to_s3(local_filename, bucket, s3_key)

        # Delete the local file after upload
        if os.path.exists(local_filename):
            os.remove(local_filename)
            print(f"Deleted local file: {local_filename}")
        else:
            print(f"File not found for deletion: {local_filename}")

    def update_metadata_and_upload(self, metadata_list, bucket_name, metadata_file_key):
        """Serialize the (already updated) metadata list to JSON and upload it to S3."""
        metadata_json = json.dumps(metadata_list, indent=4)
        self.s3_handler.upload_string_to_s3(metadata_json, bucket_name, metadata_file_key)
        print(f"Updated metadata uploaded to s3://{bucket_name}/{metadata_file_key}")

    def download_and_process_tgz(self, ftp_link, pmc_id):
        """Download an article's TGZ package, extract it, and return (metadata_entry, document)."""
        try:
            metadata_entry = {}

            # Step 1: Download the TGZ archive
            local_tgz_filename = os.path.join(self.local_download_dir, f"{pmc_id}.tgz")
            print(f"Downloading TGZ: {ftp_link} saving in {local_tgz_filename}")
            request.urlretrieve(ftp_link, local_tgz_filename)

            # Step 2: Extract the TGZ into a temporary directory
            temp_extract_dir = os.path.join(self.local_download_dir, f"{pmc_id}_temp")
            os.makedirs(temp_extract_dir, exist_ok=True)
            print(f"Temporary extract dir: {temp_extract_dir}")
            with tarfile.open(local_tgz_filename, "r:gz") as tar:
                tar.extractall(path=temp_extract_dir)

            # Step 3: Flatten any single nested root directory (e.g., PMC8419487/) into the target directory
            final_extract_dir = os.path.join(self.local_download_dir, pmc_id)
            os.makedirs(final_extract_dir, exist_ok=True)
            extracted_items = os.listdir(temp_extract_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_extract_dir, extracted_items[0])):
                # Move the contents of the single root folder to the final directory
                nested_dir = os.path.join(temp_extract_dir, extracted_items[0])
                for item in os.listdir(nested_dir):
                    shutil.move(os.path.join(nested_dir, item), final_extract_dir)
            else:
                # No single root folder: move all extracted files directly
                for item in extracted_items:
                    shutil.move(os.path.join(temp_extract_dir, item), final_extract_dir)
            print(f"Final extracted dir: {final_extract_dir}")

            # Clean up the temporary extraction directory
            shutil.rmtree(temp_extract_dir)
            print(f"Temporary extract dir deleted: {temp_extract_dir}")

            # Step 4: Locate the XML metadata file and (optionally) a PDF
            xml_files = [f for f in os.listdir(final_extract_dir) if f.endswith(".xml") or f.endswith(".nxml")]
            pdf_files = [f for f in os.listdir(final_extract_dir) if f.endswith(".pdf")]

            if xml_files:
                xml_path = os.path.join(final_extract_dir, xml_files[0])
                metadata_entry = self.process_xml_metadata(xml_path, pmc_id)
            else:
                print(f"No XML file found in TGZ for PMCID: {pmc_id}; skipping article")

            # Step 5: Build a Document from the PDF if present, otherwise fall back to the XML body text
            document = None
            if pdf_files:
                pdf_path = os.path.join(final_extract_dir, pdf_files[0])
                document = self.download_and_process_pdf(pdf_path, pmc_id, self.settings.AWS_BUCKET_NAME)
            elif metadata_entry.get("body_text") and metadata_entry["body_text"] != "N/A":
                document = Document(page_content=metadata_entry["body_text"], metadata=metadata_entry)
                metadata_entry.pop("body_text")
            else:
                print("Neither body content nor a PDF was found; skipping this article")

            # Step 6: Remove the downloaded TGZ file and the extracted directory
            if os.path.exists(local_tgz_filename):
                os.remove(local_tgz_filename)
                print(f"Removed file: {local_tgz_filename}")
            if os.path.exists(final_extract_dir):
                shutil.rmtree(final_extract_dir)

            return metadata_entry, document
        except Exception as e:
            print(f"Cannot download TGZ file for {pmc_id} : ftp link : {ftp_link}")
            print(f"[ERROR] {str(e)}")
            return {}, None

    def extract_text_from_element(self, element):
        """
        Recursively extract all text from an XML element and its children.

        Args:
            element (Element): XML element to extract text from.

        Returns:
            str: Concatenated text content of the element and its children.
        """
        text_content = element.text or ""  # Start with the element's own text
        for child in element:
            text_content += self.extract_text_from_element(child)  # Recurse into children
            if child.tail:  # Include any tail text that follows the child element
                text_content += child.tail
        return text_content.strip()

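    # Example of the text/tail handling above: for
    #   <p>Risk factors include <italic>obesity</italic> and smoking.</p>
    # the parent's .text ("Risk factors include "), the child's text ("obesity"), and the child's
    # .tail (" and smoking.") are concatenated, yielding "Risk factors include obesity and smoking."
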
    def process_xml_metadata(self, xml_path, pmc_id):
        """Parse the article's XML (JATS) file and return a metadata dictionary."""
        tree = ET.parse(xml_path)
        root = tree.getroot()

        # Title
        title_elem = root.find(".//article-title")
        title = title_elem.text if title_elem is not None else "No Title Available"

        # Abstract
        abstract_elem = root.find(".//abstract/p")
        abstract = abstract_elem.text if abstract_elem is not None else "No Abstract Available"

        # DOI
        doi_elem = root.find(".//article-id[@pub-id-type='doi']")
        doi = doi_elem.text if doi_elem is not None else "N/A"

        # Authors (handle missing surname/given-names elements safely)
        authors = []
        for author in root.findall(".//contrib/name"):
            surname = author.find("surname")
            given_names = author.find("given-names")
            surname_text = surname.text if surname is not None else "Unknown Surname"
            given_names_text = given_names.text if given_names is not None else "Unknown Given Names"
            authors.append(f"{surname_text}, {given_names_text}")

        # Keywords
        keywords = [kw.text for kw in root.findall(".//kwd")]

        # Publication date
        pub_date_node = root.find(".//pub-date")
        if pub_date_node is not None:
            month = pub_date_node.find("month").text if pub_date_node.find("month") is not None else "N/A"
            year = pub_date_node.find("year").text if pub_date_node.find("year") is not None else "N/A"
            publication_date = f"{year}-{month}" if month != "N/A" else year
        else:
            publication_date = "N/A"

        # Full body text, if a <body> element is present
        body_node = root.find(".//body")
        body_text = self.extract_text_from_element(body_node) if body_node is not None else "N/A"

        # Assemble the enriched metadata entry
        metadata_entry = {
            "pmc_id": pmc_id,
            "title": title,
            "abstract": abstract,
            "authors": authors,
            "keywords": keywords,
            "doi": doi,
            "source": f"https://pmc.ncbi.nlm.nih.gov/articles/{pmc_id}",
            "publication_date": publication_date,
            "body_text": body_text,
        }
        return metadata_entry

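    # Illustrative shape of a returned metadata_entry (all field values hypothetical):
    #   {
    #       "pmc_id": "PMC1234567",
    #       "title": "Example article title",
    #       "abstract": "Example abstract text ...",
    #       "authors": ["Doe, Jane", "Smith, John"],
    #       "keywords": ["keyword one", "keyword two"],
    #       "doi": "10.1234/example",
    #       "source": "https://pmc.ncbi.nlm.nih.gov/articles/PMC1234567",
    #       "publication_date": "2024-05",
    #       "body_text": "Full text extracted from the <body> element ...",
    #   }
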
    def download_and_process_pdf(self, pdf_path, pmc_id, bucket_name):
        """Extract text from a local PDF and wrap it in a LangChain Document."""
        try:
            pdf_reader = PdfReader(pdf_path)
            # extract_text() can return None for image-only pages, so fall back to ""
            text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

            document = Document(
                page_content=text,
                metadata={"source": f"s3://{bucket_name}/{pmc_id}.pdf"}
            )
            return document
        except Exception as e:
            print(f"Error processing PDF for {pmc_id}: {e}")
            return None

    def process_and_save(self, bucket_name, metadata_file_key):
        """Iterate over all records in the date range, index new articles, and keep the S3 metadata file up to date."""
        # Load existing metadata from S3 so already-indexed articles can be skipped
        try:
            metadata_content = self.s3_handler.download_string_from_s3(bucket_name, metadata_file_key)
            existing_metadata = json.loads(metadata_content)
            existing_ids = {record["pmc_id"] for record in existing_metadata}
            print(f"Found {len(existing_ids)} existing records in metadata.")
        except Exception as e:
            # If the metadata file doesn't exist or is empty, start from scratch
            print(f"Could not load metadata: {e}. Assuming no existing records.")
            existing_metadata = []
            existing_ids = set()

        resumption_token = None
        while True:
            root = self.fetch_records(resumption_token=resumption_token)
            print(f"Number of records in this page: {len(root.findall('.//record'))}")

            for record in root.findall(".//record"):
                pmc_id = record.attrib.get("id")
                if pmc_id in existing_ids:
                    # Skip records that were already downloaded and indexed
                    continue

                # Collect the download links advertised for this record
                pdf_link = None
                ftp_link = None
                for link in record.findall("link"):
                    if link.attrib.get("format") == "tgz":
                        ftp_link = link.attrib.get("href")
                    if link.attrib.get("format") == "pdf":
                        pdf_link = link.attrib.get("href")
                print(f"[INFO] links found: pdf {pdf_link} and ftp {ftp_link}")

                # Prefer the TGZ package, which bundles the XML metadata (and usually the PDF)
                metadata = {}
                document = None
                if ftp_link:
                    metadata, document = self.download_and_process_tgz(ftp_link, pmc_id)
                if not document or not metadata:
                    # Nothing usable to index for this record
                    continue

                self.split_and_embed([document], metadata)
                existing_metadata.append(metadata)
                self.update_metadata_and_upload(existing_metadata, bucket_name, metadata_file_key)

            # Follow the resumption token to the next page, if any
            resumption = root.find(".//resumption")
            if resumption is not None:
                link = resumption.find("link")
                if link is not None:
                    resumption_token = link.attrib.get("token", "").strip()
                    if not resumption_token:
                        print("No more tokens found, stopping pagination.")
                        break
                else:
                    print("No link found, stopping pagination.")
                    break
            else:
                print("No resumption element, stopping pagination.")
                break


def create_or_connect_index(index_name, dimension):
    """Create the Pinecone index if it does not exist yet, then return a handle to it."""
    settings = get_settings()
    pc = pinecone.Pinecone(settings.PINECONE_API_KEY)
    spec = pinecone.ServerlessSpec(
        cloud=settings.CLOUD,
        region=settings.REGION
    )
    print(f"All indexes: {pc.list_indexes()}")
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",  # 'dotproduct' or other metrics could be used instead if needed
            spec=spec
        )
    return pc.Index(index_name)

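# Note: the index dimension must match the output size of the embedding model used for upserts
# (e.g. 1536 for OpenAI's text-embedding-ada-002); settings.DIMENSIONS is assumed to be set
# accordingly in config.py.

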
if __name__ == "__main__":
    # TODO: expose these as command-line arguments
    #       (from/until dates, limits, and how many pagination iterations to run)

    # Load settings
    settings = get_settings()

    # Initialize the S3 handler
    s3_handler = S3Handler()

    # Create or connect to the Pinecone index
    pc_index = create_or_connect_index(settings.INDEX_NAME, settings.DIMENSIONS)

    # Create the downloader instance
    downloader = PubMedDownloader(
        s3_handler=s3_handler,
        pubmed_base_url=settings.PUBMED_BASE_URL,
        pinecone_index=pc_index,
        embedding_model=OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
    )

    # Fetch, process, and index the articles
    downloader.process_and_save(
        bucket_name=settings.AWS_BUCKET_NAME,
        metadata_file_key="pubmed_metadata/metadata.json"
    )