import os import json import shutil from langchain_openai import OpenAIEmbeddings from langchain_text_splitters import RecursiveCharacterTextSplitter import requests import tarfile from langchain.schema import Document import hashlib import xml.etree.ElementTree as ET from urllib import request from s3_utils import S3Handler from config import get_settings from PyPDF2 import PdfReader class PubMedDownloader: def __init__(self, s3_handler, pubmed_base_url, pinecone_index, embedding_model, from_date="2024-01-01", until_date="2024-11-01", limit=3): self.s3_handler = s3_handler self.settings = get_settings() self.pubmed_base_url = pubmed_base_url self.from_date = from_date self.until_date = until_date self.limit = limit self.local_download_dir = "downloaded_pdfs" os.makedirs(self.local_download_dir, exist_ok=True) self.pinecone_index = pinecone_index # Pinecone index instance self.embedding_model = embedding_model # Embedding model instance def split_and_embed(self, documents, metadata_entry): """Split documents into chunks and embed them sequentially.""" text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.settings.CHUNK_SIZE, chunk_overlap=self.settings.CHUNK_OVERLAP ) chunks = text_splitter.split_documents(documents) print(f'total chunks created: {len(chunks)}') batch_size = 50 pmc_id = metadata_entry['pmc_id'] for batch_index in range(0, len(chunks), batch_size): batch = chunks[batch_index: batch_index + batch_size] print(f'len of batch: {len(batch)}') try: # Process a single batch # Create ids for the batch # ids = [f"chunk_{batch_index}_{j}" for j in range(len(batch))] ids = [f"{pmc_id}_chunk_{batch_index}_{j}" for j in range(len(batch))] print(f'len of ids: {len(ids)}') print(f'id sample: {ids[0]}') # Get texts and generate embeddings texts = [doc.page_content for doc in batch] print(f'len of texts: {len(texts)}') embeddings = self.embedding_model.embed_documents(texts) metadata = [] for doc in batch: chunk_metadata = metadata_entry.copy() # Copy base metadata chunk_metadata["text"] = doc.page_content # Add chunk-specific text metadata.append(chunk_metadata) # Create upsert batch to_upsert = list(zip(ids, embeddings, metadata)) # Upsert to Pinecone self.pinecone_index.upsert(vectors=to_upsert) print(f"Successfully upserted {len(to_upsert)} chunks to Pinecone.") except Exception as e: print(f"Error processing batch {batch_index}: {e}") def fetch_records(self, resumption_token=None): """ Fetch records from PubMed using optional resumptionToken. Args: resumption_token (str, optional): Token to resume fetching records. Defaults to None. Returns: ElementTree.Element: Parsed XML root of the API response. """ # Build the base URL url = f"{self.pubmed_base_url}" # Define parameters params = { "format" : "tgz" } # Add date range if provided if self.from_date and self.until_date: params["from"] = self.from_date params["until"] = self.until_date # Add resumptionToken if available if resumption_token: params["resumptionToken"] = resumption_token print(f"Using resumption token: {resumption_token}") # Make the request response = requests.get(url, params=params) response.raise_for_status() # Raise an error for bad HTTP responses # Parse and return the XML content return ET.fromstring(response.content) def save_metadata_to_s3(self, metadata, bucket, key): print(f"Saving metadata to S3: s3://{bucket}/{key}") self.s3_handler.upload_string_to_s3(metadata, bucket, key) def save_pdf_to_s3(self, local_filename, bucket, s3_key): """Upload PDF to S3 and then delete the local file.""" print(f"Uploading PDF to S3: s3://{bucket}/{s3_key}") self.s3_handler.upload_file_to_s3(local_filename, bucket, s3_key) # Delete the local file after upload if os.path.exists(local_filename): os.remove(local_filename) print(f"Deleted local file: {local_filename}") else: print(f"File not found for deletion: {local_filename}") def update_metadata_and_upload(self, metadata_entry, bucket_name, metadata_file_key): """Update metadata list with a new entry and upload it to S3 as JSON.""" # Add new entry to metadata # Convert metadata to JSON and upload to S3 metadata_json = json.dumps(metadata_entry, indent=4) self.s3_handler.upload_string_to_s3(metadata_json, bucket_name, metadata_file_key) print(f"Updated metadata uploaded to s3://{bucket_name}/{metadata_file_key}") def download_and_process_tgz(self, ftp_link, pmc_id): try: metadata_entry = {} # Step 1: Download TGZ local_tgz_filename = os.path.join(self.local_download_dir, f"{pmc_id}.tgz") print(f"Downloading TGZ: {ftp_link} saving in {local_tgz_filename}") request.urlretrieve(ftp_link, local_tgz_filename) # Step 2: Extract TGZ into a temporary directory temp_extract_dir = os.path.join(self.local_download_dir, f"{pmc_id}_temp") os.makedirs(temp_extract_dir, exist_ok=True) print(f"Temporary extract dir: {temp_extract_dir}") with tarfile.open(local_tgz_filename, "r:gz") as tar: tar.extractall(path=temp_extract_dir) # Step 3: Handle Nested Structure (Move Contents to Target Directory) final_extract_dir = os.path.join(self.local_download_dir, pmc_id) os.makedirs(final_extract_dir, exist_ok=True) # Check if the archive creates a single root directory (e.g., PMC8419487/) extracted_items = os.listdir(temp_extract_dir) if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_extract_dir, extracted_items[0])): # Move contents of the single folder to the final directory nested_dir = os.path.join(temp_extract_dir, extracted_items[0]) for item in os.listdir(nested_dir): shutil.move(os.path.join(nested_dir, item), final_extract_dir) else: # If no single root folder, move all files directly for item in extracted_items: shutil.move(os.path.join(temp_extract_dir, item), final_extract_dir) print(f"Final extracted dir: {final_extract_dir}") # Clean up the temporary extraction directory shutil.rmtree(temp_extract_dir) print(f"Temporary extract dir deleted: {temp_extract_dir}") # Process the extracted files as before... xml_file = [f for f in os.listdir(final_extract_dir) if f.endswith(".xml") or f.endswith(".nxml")] pdf_path = [f for f in os.listdir(final_extract_dir) if f.endswith("pdf")] if xml_file: xml_path = os.path.join(final_extract_dir, xml_file[0]) metadata_entry = self.process_xml_metadata(xml_path, pmc_id) else: print(f"No XML file found in TGZ for PMCID: {pmc_id}") print(f'Skipping article') if pdf_path: pdf_path = os.path.join(final_extract_dir, pdf_path[0]) document = self.download_and_process_pdf(pdf_path, pmc_id, self.settings.AWS_BUCKET_NAME) else: if metadata_entry.get('body_text') and metadata_entry['body_text'] != "N/A": document = Document( page_content=metadata_entry['body_text'], metadata=metadata_entry ) metadata_entry.pop("body_text") else: print(f'Body content and PDF both not found, hence skipping this PDF') document = None # Cleanup: Remove the downloaded TGZ file if os.path.exists(local_tgz_filename): os.remove(local_tgz_filename) print(f"Removed file: {local_tgz_filename}") if os.path.exists(final_extract_dir): shutil.rmtree(final_extract_dir) return metadata_entry, document except Exception as e: print(f"Cannot download TGZ file for {pmc_id} : ftp link : {ftp_link}") print(f"[ERROR] {str(e)}") return {}, None def extract_text_from_element(self, element): """ Recursively extract all text from an XML element and its children. Args: element (Element): XML element to extract text from. Returns: str: Concatenated text content of the element and its children. """ text_content = element.text or "" # Start with the element's own text for child in element: text_content += self.extract_text_from_element(child) # Recurse into children if child.tail: # Include any tail text after the child element text_content += child.tail return text_content.strip() def process_xml_metadata(self, xml_path, pmc_id): tree = ET.parse(xml_path) root = tree.getroot() # Extract metadata title_elem = root.find(".//article-title") title = title_elem.text if title_elem is not None else "No Title Available" # title = root.find(".//article-title").text if root.find(".//article-title") else "No Title Available" # abstract = root.find(".//abstract/p").text if root.find(".//abstract/p") else "No Abstract Available" # Abstract extraction abstract_elem = root.find(".//abstract/p") abstract = abstract_elem.text if abstract_elem is not None else "No Abstract Available" # doi = root.find(".//article-id[@pub-id-type='doi']").text if root.find(".//article-id[@pub-id-type='doi']") else "N/A" # DOI extraction doi_elem = root.find(".//article-id[@pub-id-type='doi']") doi = doi_elem.text if doi_elem is not None else "N/A" # authors = [f"{author.find('surname').text}, {author.find('given-names').text}" # for author in root.findall(".//contrib/name")] authors = [] for author in root.findall(".//contrib/name"): surname = author.find('surname') given_names = author.find('given-names') # Safely handle missing elements surname_text = surname.text if surname is not None else "Unknown Surname" given_names_text = given_names.text if given_names is not None else "Unknown Given Names" authors.append(f"{surname_text}, {given_names_text}") keywords = [kw.text for kw in root.findall(".//kwd")] # Extract publication date pub_date_node = root.find(".//pub-date") if pub_date_node is not None: month = pub_date_node.find("month").text if pub_date_node.find("month") is not None else "N/A" year = pub_date_node.find("year").text if pub_date_node.find("year") is not None else "N/A" pub_type = pub_date_node.attrib.get("pub-type", "N/A") publication_date = f"{year}-{month}" if month != "N/A" else year else: publication_date = "N/A" # Extract text content from body_node = root.find(".//body") body_text = "" if body_node is not None: body_text = self.extract_text_from_element(body_node) else: body_text = "N/A" # Save enriched metadata metadata_entry = { "pmc_id": pmc_id, "title": title, "abstract": abstract, "authors": authors, "keywords": keywords, "doi": doi, "source": f"https://pmc.ncbi.nlm.nih.gov/articles/{pmc_id}", "publication_date" : publication_date, "body_text" : body_text } return metadata_entry def download_and_process_pdf(self, pdf_path, pmc_id, bucket_name): try: pdf_reader = PdfReader(pdf_path) text = "".join(page.extract_text() for page in pdf_reader.pages) # Create document object document = Document( page_content=text, metadata={"source": f"s3://{bucket_name}/{pmc_id}.pdf"} ) return document except Exception as e: print(f"Error processing PDF for {pmc_id}: {e}") return None def process_and_save(self, bucket_name, metadata_file_key): # Load existing metadata from S3 try: metadata_content = self.s3_handler.download_string_from_s3(bucket_name, metadata_file_key) existing_metadata = json.loads(metadata_content) existing_ids = {record["pmc_id"] for record in existing_metadata} print(f"Found {len(existing_ids)} existing records in metadata.") except Exception as e: # If metadata file doesn't exist or is empty, initialize an empty list print(f"Could not load metadata: {e}. Assuming no existing records.") existing_metadata = [] existing_ids = set() resumption_token = None while True: root = self.fetch_records(resumption_token=resumption_token) print(f'len of records: {len(root.findall(".//record"))}') resumption = root.find(".//resumption") print(f'resumption token: {resumption}') for record in root.findall(".//record"): # print(f'first record: ') pmc_id = record.attrib.get("id") # print(f'[INFO] pmc id : {pmc_id}') if pmc_id in existing_ids: # print(f"Skipping already downloaded record: {pmc_id}") continue pdf_link = None ftp_link = None for link in record.findall("link"): if link.attrib.get("format") == "tgz": ftp_link = link.attrib.get("href") if link.attrib.get("format") == "pdf": pdf_link = link.attrib.get("href") print(f'[INFO] links found: pdf {pdf_link} and ftp {ftp_link}') metadata = { } # Process `tgz` first if available if ftp_link: metadata, document = self.download_and_process_tgz(ftp_link, pmc_id) # documents.append(document) if not document: # print(f'this document doesnt have content. continue .. ') continue self.split_and_embed([document], metadata) # Create document object existing_metadata.append(metadata) self.update_metadata_and_upload(existing_metadata, bucket_name , metadata_file_key) resumption = root.find(".//resumption") if resumption is not None: link = resumption.find("link") if link is not None: resumption_token = link.attrib.get("token", "").strip() if not resumption_token: print("No more tokens found, stopping pagination.") break else: print("No link found, stopping pagination.") break else: print("No resumption element, stopping pagination.") break def create_or_connect_index(index_name, dimension): pc = pinecone.Pinecone(settings.PINECONE_API_KEY) """Create or connect to existing Pinecone index""" spec = pinecone.ServerlessSpec( cloud=settings.CLOUD, region=settings.REGION ) print(f'all indexes: {pc.list_indexes()}') if index_name not in pc.list_indexes().names(): pc.create_index( name=index_name, dimension=dimension, metric='cosine', # You can use 'dotproduct' or other metrics if needed spec=spec ) return pc.Index(settings.INDEX_NAME) if __name__ == "__main__": """ #todo: add all args as argument parser #todo: like from and until date, and all variables #todo: add one variable like how many iterations we need to go """ # Load settings settings = get_settings() # Initialize S3 handler s3_handler = S3Handler() import pinecone pc_index = create_or_connect_index(settings.INDEX_NAME, settings.DIMENSIONS) # Create the downloader instance downloader = PubMedDownloader( s3_handler=s3_handler, pubmed_base_url=settings.PUBMED_BASE_URL, pinecone_index= pc_index, embedding_model=OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY) ) # Process and save downloader.process_and_save( bucket_name=settings.AWS_BUCKET_NAME, metadata_file_key="pubmed_metadata/metadata.json" )