# sehatech-demo/core/download_dataset.py
import os
import json
import shutil
import tarfile
import xml.etree.ElementTree as ET
from urllib import request

import requests
import pinecone
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from PyPDF2 import PdfReader

from s3_utils import S3Handler
from config import get_settings
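
# Settings read from config.get_settings() in this module:
# PUBMED_BASE_URL, AWS_BUCKET_NAME, CHUNK_SIZE, CHUNK_OVERLAP, OPENAI_API_KEY,
# PINECONE_API_KEY, INDEX_NAME, DIMENSIONS, CLOUD, REGION.
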
class PubMedDownloader:
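    """
    Downloads PubMed Central open-access article packages, extracts metadata and
    full text, embeds text chunks into a Pinecone index, and tracks article
    metadata in S3.
    """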
def __init__(self, s3_handler, pubmed_base_url, pinecone_index, embedding_model, from_date="2024-01-01", until_date="2024-11-01", limit=3):
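        """
        Args:
            s3_handler: S3Handler used to upload/download metadata and files.
            pubmed_base_url: Base URL of the record-listing API.
            pinecone_index: Pinecone index instance used for upserts.
            embedding_model: Embedding model used to embed text chunks.
            from_date, until_date: Date range passed to the API.
            limit: Stored for later use; not referenced elsewhere in this script.
        """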
self.s3_handler = s3_handler
self.settings = get_settings()
self.pubmed_base_url = pubmed_base_url
self.from_date = from_date
self.until_date = until_date
self.limit = limit
self.local_download_dir = "downloaded_pdfs"
os.makedirs(self.local_download_dir, exist_ok=True)
self.pinecone_index = pinecone_index # Pinecone index instance
self.embedding_model = embedding_model # Embedding model instance
def split_and_embed(self, documents, metadata_entry):
"""Split documents into chunks and embed them sequentially."""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.settings.CHUNK_SIZE,
chunk_overlap=self.settings.CHUNK_OVERLAP
)
chunks = text_splitter.split_documents(documents)
print(f'total chunks created: {len(chunks)}')
batch_size = 50
pmc_id = metadata_entry['pmc_id']
for batch_index in range(0, len(chunks), batch_size):
batch = chunks[batch_index: batch_index + batch_size]
print(f'len of batch: {len(batch)}')
try:
# Process a single batch
# Create ids for the batch
# ids = [f"chunk_{batch_index}_{j}" for j in range(len(batch))]
ids = [f"{pmc_id}_chunk_{batch_index}_{j}" for j in range(len(batch))]
print(f'len of ids: {len(ids)}')
print(f'id sample: {ids[0]}')
# Get texts and generate embeddings
texts = [doc.page_content for doc in batch]
print(f'len of texts: {len(texts)}')
embeddings = self.embedding_model.embed_documents(texts)
metadata = []
for doc in batch:
chunk_metadata = metadata_entry.copy() # Copy base metadata
chunk_metadata["text"] = doc.page_content # Add chunk-specific text
metadata.append(chunk_metadata)
# Create upsert batch
to_upsert = list(zip(ids, embeddings, metadata))
# Upsert to Pinecone
self.pinecone_index.upsert(vectors=to_upsert)
print(f"Successfully upserted {len(to_upsert)} chunks to Pinecone.")
except Exception as e:
print(f"Error processing batch {batch_index}: {e}")
def fetch_records(self, resumption_token=None):
"""
Fetch records from PubMed using optional resumptionToken.
Args:
resumption_token (str, optional): Token to resume fetching records. Defaults to None.
Returns:
ElementTree.Element: Parsed XML root of the API response.
"""
        # Build the request URL and parameters
        url = self.pubmed_base_url
        params = {
            "format": "tgz"
        }
# Add date range if provided
if self.from_date and self.until_date:
params["from"] = self.from_date
params["until"] = self.until_date
# Add resumptionToken if available
if resumption_token:
params["resumptionToken"] = resumption_token
print(f"Using resumption token: {resumption_token}")
# Make the request
response = requests.get(url, params=params)
response.raise_for_status() # Raise an error for bad HTTP responses
# Parse and return the XML content
return ET.fromstring(response.content)
def save_metadata_to_s3(self, metadata, bucket, key):
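        """Upload an already-serialized metadata string to S3 under the given bucket/key."""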
print(f"Saving metadata to S3: s3://{bucket}/{key}")
self.s3_handler.upload_string_to_s3(metadata, bucket, key)
def save_pdf_to_s3(self, local_filename, bucket, s3_key):
"""Upload PDF to S3 and then delete the local file."""
print(f"Uploading PDF to S3: s3://{bucket}/{s3_key}")
self.s3_handler.upload_file_to_s3(local_filename, bucket, s3_key)
# Delete the local file after upload
if os.path.exists(local_filename):
os.remove(local_filename)
print(f"Deleted local file: {local_filename}")
else:
print(f"File not found for deletion: {local_filename}")
    def update_metadata_and_upload(self, metadata_entry, bucket_name, metadata_file_key):
        """Serialize the metadata list to JSON and upload it to S3."""
        metadata_json = json.dumps(metadata_entry, indent=4)
        self.s3_handler.upload_string_to_s3(metadata_json, bucket_name, metadata_file_key)
        print(f"Updated metadata uploaded to s3://{bucket_name}/{metadata_file_key}")
def download_and_process_tgz(self, ftp_link, pmc_id):
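        """
        Download a PMC TGZ package, extract it, parse the article XML into metadata,
        and build a LangChain Document from the PDF (or the XML body text as a fallback).

        Returns:
            tuple: (metadata_entry, document); both may be empty/None on failure or skip.
        """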
try:
metadata_entry = {}
# Step 1: Download TGZ
local_tgz_filename = os.path.join(self.local_download_dir, f"{pmc_id}.tgz")
print(f"Downloading TGZ: {ftp_link} saving in {local_tgz_filename}")
request.urlretrieve(ftp_link, local_tgz_filename)
# Step 2: Extract TGZ into a temporary directory
temp_extract_dir = os.path.join(self.local_download_dir, f"{pmc_id}_temp")
os.makedirs(temp_extract_dir, exist_ok=True)
print(f"Temporary extract dir: {temp_extract_dir}")
with tarfile.open(local_tgz_filename, "r:gz") as tar:
tar.extractall(path=temp_extract_dir)
# Step 3: Handle Nested Structure (Move Contents to Target Directory)
final_extract_dir = os.path.join(self.local_download_dir, pmc_id)
os.makedirs(final_extract_dir, exist_ok=True)
# Check if the archive creates a single root directory (e.g., PMC8419487/)
extracted_items = os.listdir(temp_extract_dir)
if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_extract_dir, extracted_items[0])):
# Move contents of the single folder to the final directory
nested_dir = os.path.join(temp_extract_dir, extracted_items[0])
for item in os.listdir(nested_dir):
shutil.move(os.path.join(nested_dir, item), final_extract_dir)
else:
# If no single root folder, move all files directly
for item in extracted_items:
shutil.move(os.path.join(temp_extract_dir, item), final_extract_dir)
print(f"Final extracted dir: {final_extract_dir}")
# Clean up the temporary extraction directory
shutil.rmtree(temp_extract_dir)
print(f"Temporary extract dir deleted: {temp_extract_dir}")
            # Locate the article XML and PDF inside the extracted package
            xml_files = [f for f in os.listdir(final_extract_dir) if f.endswith(".xml") or f.endswith(".nxml")]
            pdf_files = [f for f in os.listdir(final_extract_dir) if f.endswith(".pdf")]
            if xml_files:
                xml_path = os.path.join(final_extract_dir, xml_files[0])
                metadata_entry = self.process_xml_metadata(xml_path, pmc_id)
            else:
                # Without the XML there is no metadata (including the PMC id used for chunk ids), so skip this article
                print(f"No XML file found in TGZ for PMCID: {pmc_id}, skipping article.")
                os.remove(local_tgz_filename)
                shutil.rmtree(final_extract_dir)
                return {}, None
            if pdf_files:
                pdf_path = os.path.join(final_extract_dir, pdf_files[0])
                document = self.download_and_process_pdf(pdf_path, pmc_id, self.settings.AWS_BUCKET_NAME)
                # Drop the full body text so it is not duplicated into every chunk's vector metadata
                metadata_entry.pop("body_text", None)
            else:
                if metadata_entry.get('body_text') and metadata_entry['body_text'] != "N/A":
                    document = Document(
                        page_content=metadata_entry['body_text'], metadata=metadata_entry
                    )
                    metadata_entry.pop("body_text")
                else:
                    print('Neither body text nor a PDF was found; skipping this article.')
                    document = None
# Cleanup: Remove the downloaded TGZ file
if os.path.exists(local_tgz_filename):
os.remove(local_tgz_filename)
print(f"Removed file: {local_tgz_filename}")
if os.path.exists(final_extract_dir):
shutil.rmtree(final_extract_dir)
return metadata_entry, document
except Exception as e:
print(f"Cannot download TGZ file for {pmc_id} : ftp link : {ftp_link}")
print(f"[ERROR] {str(e)}")
return {}, None
def extract_text_from_element(self, element):
"""
Recursively extract all text from an XML element and its children.
Args:
element (Element): XML element to extract text from.
Returns:
str: Concatenated text content of the element and its children.
"""
text_content = element.text or "" # Start with the element's own text
for child in element:
text_content += self.extract_text_from_element(child) # Recurse into children
if child.tail: # Include any tail text after the child element
text_content += child.tail
return text_content.strip()
def process_xml_metadata(self, xml_path, pmc_id):
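        """Parse the article XML (.xml/.nxml) and return a metadata dict with title, abstract, authors, keywords, DOI, source URL, publication date, and body text."""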
tree = ET.parse(xml_path)
root = tree.getroot()
# Extract metadata
        title_elem = root.find(".//article-title")
        title = title_elem.text if title_elem is not None else "No Title Available"
        # Abstract extraction (first paragraph)
        abstract_elem = root.find(".//abstract/p")
        abstract = abstract_elem.text if abstract_elem is not None else "No Abstract Available"
        # DOI extraction
        doi_elem = root.find(".//article-id[@pub-id-type='doi']")
        doi = doi_elem.text if doi_elem is not None else "N/A"
        # Author extraction, tolerating missing name parts
        authors = []
for author in root.findall(".//contrib/name"):
surname = author.find('surname')
given_names = author.find('given-names')
# Safely handle missing elements
surname_text = surname.text if surname is not None else "Unknown Surname"
given_names_text = given_names.text if given_names is not None else "Unknown Given Names"
authors.append(f"{surname_text}, {given_names_text}")
        keywords = [kw.text for kw in root.findall(".//kwd") if kw.text]
# Extract publication date
pub_date_node = root.find(".//pub-date")
if pub_date_node is not None:
month = pub_date_node.find("month").text if pub_date_node.find("month") is not None else "N/A"
year = pub_date_node.find("year").text if pub_date_node.find("year") is not None else "N/A"
pub_type = pub_date_node.attrib.get("pub-type", "N/A")
publication_date = f"{year}-{month}" if month != "N/A" else year
else:
publication_date = "N/A"
# Extract text content from <body>
        body_node = root.find(".//body")
        body_text = self.extract_text_from_element(body_node) if body_node is not None else "N/A"
# Save enriched metadata
metadata_entry = {
"pmc_id": pmc_id,
"title": title,
"abstract": abstract,
"authors": authors,
"keywords": keywords,
"doi": doi,
"source": f"https://pmc.ncbi.nlm.nih.gov/articles/{pmc_id}",
"publication_date" : publication_date,
"body_text" : body_text
}
return metadata_entry
def download_and_process_pdf(self, pdf_path, pmc_id, bucket_name):
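        """Extract text from a local PDF with PyPDF2 and wrap it in a Document whose 'source' metadata points to the article's expected S3 location."""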
try:
pdf_reader = PdfReader(pdf_path)
text = "".join(page.extract_text() for page in pdf_reader.pages)
# Create document object
document = Document(
page_content=text,
metadata={"source": f"s3://{bucket_name}/{pmc_id}.pdf"}
)
return document
except Exception as e:
print(f"Error processing PDF for {pmc_id}: {e}")
return None
def process_and_save(self, bucket_name, metadata_file_key):
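        """
        Fetch records page by page using resumption tokens, skip PMC ids already present
        in the S3 metadata file, embed each new article into Pinecone, and re-upload the
        updated metadata JSON after every article.
        """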
# Load existing metadata from S3
try:
metadata_content = self.s3_handler.download_string_from_s3(bucket_name, metadata_file_key)
existing_metadata = json.loads(metadata_content)
existing_ids = {record["pmc_id"] for record in existing_metadata}
print(f"Found {len(existing_ids)} existing records in metadata.")
except Exception as e:
# If metadata file doesn't exist or is empty, initialize an empty list
print(f"Could not load metadata: {e}. Assuming no existing records.")
existing_metadata = []
existing_ids = set()
resumption_token = None
while True:
root = self.fetch_records(resumption_token=resumption_token)
            records = root.findall(".//record")
            print(f"Fetched {len(records)} records.")
            for record in records:
                pmc_id = record.attrib.get("id")
                if pmc_id in existing_ids:
                    continue
pdf_link = None
ftp_link = None
for link in record.findall("link"):
if link.attrib.get("format") == "tgz":
ftp_link = link.attrib.get("href")
if link.attrib.get("format") == "pdf":
pdf_link = link.attrib.get("href")
print(f'[INFO] links found: pdf {pdf_link} and ftp {ftp_link}')
                metadata = {}
                # Process the `tgz` package if available
                if ftp_link:
                    metadata, document = self.download_and_process_tgz(ftp_link, pmc_id)
                    if not document:
                        # No usable content was extracted for this record; move on
                        continue
                    self.split_and_embed([document], metadata)
                    existing_metadata.append(metadata)
                    # Remember this id so it is not reprocessed later in the run
                    existing_ids.add(pmc_id)
                    self.update_metadata_and_upload(existing_metadata, bucket_name, metadata_file_key)
resumption = root.find(".//resumption")
if resumption is not None:
link = resumption.find("link")
if link is not None:
resumption_token = link.attrib.get("token", "").strip()
if not resumption_token:
print("No more tokens found, stopping pagination.")
break
else:
print("No link found, stopping pagination.")
break
else:
print("No resumption element, stopping pagination.")
break
def create_or_connect_index(index_name, dimension):
    """Create the Pinecone index if it does not exist, then return a handle to it."""
    settings = get_settings()
    pc = pinecone.Pinecone(settings.PINECONE_API_KEY)
    spec = pinecone.ServerlessSpec(
        cloud=settings.CLOUD,
        region=settings.REGION
    )
    print(f'all indexes: {pc.list_indexes()}')
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=dimension,
            metric='cosine',  # 'dotproduct' or other metrics can be used if needed
            spec=spec
        )
    return pc.Index(index_name)
if __name__ == "__main__":
    # TODO: expose from_date, until_date, and the other variables as command-line arguments
    # TODO: add an option controlling how many pagination iterations to run
    # Load settings
    settings = get_settings()
    # Initialize S3 handler
    s3_handler = S3Handler()
    # Create or connect to the Pinecone index
    pc_index = create_or_connect_index(settings.INDEX_NAME, settings.DIMENSIONS)
    # Create the downloader instance
    downloader = PubMedDownloader(
        s3_handler=s3_handler,
        pubmed_base_url=settings.PUBMED_BASE_URL,
        pinecone_index=pc_index,
        embedding_model=OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
    )
# Process and save
downloader.process_and_save(
bucket_name=settings.AWS_BUCKET_NAME,
metadata_file_key="pubmed_metadata/metadata.json"
)