gravity-dev / app.py
broadfield-dev's picture
Update app.py
2ae0bde verified
raw
history blame
27.7 kB
import os
import io
import requests
import logging
import re
import json
import base64
from flask import Flask, request, render_template, jsonify, send_file, Response
from PyPDF2 import PdfReader, PdfWriter
import pytesseract
from pdf2image import convert_from_bytes
from PIL import Image
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from datetime import datetime
from numpy import dot
from numpy.linalg import norm
from huggingface_hub import HfApi, hf_hub_download
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import chromadb
from chromadb.utils import embedding_functions
import shutil
# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set cache, uploads, and output directories.
# NOTE(review): huggingface_hub / transformers are imported above this point and
# may read HF_HOME / TRANSFORMERS_CACHE at import time, so these assignments
# might not redirect their caches. Consider setting them before the imports or
# in the container environment.
CACHE_DIR = "/app/cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = CACHE_DIR
os.environ["XDG_CACHE_HOME"] = CACHE_DIR
UPLOADS_DIR = "/app/uploads"
PAGES_DIR = os.path.join(UPLOADS_DIR, "pages")
OUTPUT_DIR = "/app/output"
COMBINED_PDF_PATH = os.path.join(OUTPUT_DIR, "combined_output.pdf")
PROGRESS_JSON_PATH = os.path.join(OUTPUT_DIR, "progress_log.json")
CHROMA_DB_PATH = os.path.join(OUTPUT_DIR, "chromadb")

# Create every directory this module writes into. CACHE_DIR is included
# because TRACKING_FILE and the "hf" progress-log staging file live there.
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(PAGES_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
app = Flask(__name__)

# Hugging Face Hub configuration
HF_TOKEN = os.getenv("HF_TOKEN")
HF_DATASET_REPO = "broadfield-dev/pdf-ocr-dataset"
HF_API = HfApi()

# Tracking file for resuming (same path as before: /app/cache/processing_state.json)
TRACKING_FILE = os.path.join(CACHE_DIR, "processing_state.json")
# Load the sentence transformer used to embed OCR text and search queries.
# If loading fails, `embedder` is set to None so the module-level name always
# exists. Code that calls `embedder.encode` then fails with a clear
# AttributeError instead of a NameError at some later, unrelated point.
try:
    embedder = SentenceTransformer('all-MiniLM-L6-v2', cache_folder="/app/cache")
    logger.info("SentenceTransformer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load SentenceTransformer: {e}")
    embedder = None
# Initialize TrOCR (CPU-only).
# Both handles end up as None if anything in the load fails; ocr_with_trocr
# checks for that and reports "TrOCR not initialized".
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
trocr_processor = None
trocr_model = None
try:
    trocr_processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
    trocr_model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
    trocr_model.to("cpu").eval()  # inference only, pinned to CPU
    logger.info("TrOCR initialized successfully on CPU")
except Exception as e:
    logger.error(f"Failed to initialize TrOCR: {e}")
    trocr_model = None
    trocr_processor = None
# Initialize ChromaDB: a persistent "pdf_pages" collection stored under
# CHROMA_DB_PATH, embedded with the all-MiniLM-L6-v2 sentence transformer.
# NOTE(review): the embedding function presumably loads its own copy of the
# model, separate from `embedder`. Confirm if memory becomes a concern.
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
)
chroma_collection = chroma_client.get_or_create_collection(
    embedding_function=sentence_transformer_ef,
    name="pdf_pages",
)
# Load or initialize progress log
def load_progress_log(storage_mode):
    """Return the per-URL progress log for *storage_mode*.

    "hf": download ``progress_log.json`` from HF_DATASET_REPO. Any failure
    (missing file, network, auth, bad JSON) is logged and yields a fresh log.
    Any other mode: read PROGRESS_JSON_PATH if it exists, otherwise return a
    fresh log.

    A fresh log is ``{"urls": {}}``.
    """
    if storage_mode != "hf":
        # local
        if not os.path.exists(PROGRESS_JSON_PATH):
            return {"urls": {}}
        with open(PROGRESS_JSON_PATH, "r") as log_file:
            return json.load(log_file)

    try:
        downloaded_path = hf_hub_download(repo_id=HF_DATASET_REPO, filename="progress_log.json", repo_type="dataset", token=HF_TOKEN)
        with open(downloaded_path, "r") as log_file:
            return json.load(log_file)
    except Exception as e:
        logger.info(f"No HF progress log found or error loading: {e}, initializing new log")
        return {"urls": {}}
def save_progress_log(progress_log, storage_mode):
    """Persist *progress_log* according to *storage_mode*.

    "hf": write it to a staging file under /app/cache and upload it to
    HF_DATASET_REPO as ``progress_log.json``.
    Any other mode: write it to PROGRESS_JSON_PATH.
    """
    if storage_mode == "hf":
        staging_path = "/app/cache/progress_log.json"
        # Make sure the staging directory exists before writing into it.
        os.makedirs(os.path.dirname(staging_path), exist_ok=True)
        with open(staging_path, "w") as f:
            json.dump(progress_log, f)
        HF_API.upload_file(
            path_or_fileobj=staging_path,
            path_in_repo="progress_log.json",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN
        )
        logger.info("Progress log updated in Hugging Face dataset")
    else:  # local
        with open(PROGRESS_JSON_PATH, "w") as f:
            json.dump(progress_log, f)
        logger.info("Progress log updated locally")
# Tesseract OCR with bounding boxes
def ocr_with_tesseract(pdf_bytes, page_num):
    """OCR a single PDF page with Tesseract, keeping per-word boxes.

    Args:
        pdf_bytes: Raw bytes of the whole PDF.
        page_num: Zero-based index of the page to OCR.

    Returns:
        dict with "page_num" (1-based), "text" (Tesseract's full-page string),
        and "words" (non-blank tokens, each with left/top/width/height taken
        from ``image_to_data``). Errors are reported in "text" instead of being
        raised, with an empty "words" list.
    """
    one_based = page_num + 1
    try:
        rendered = convert_from_bytes(pdf_bytes, first_page=one_based, last_page=one_based)
        if not rendered:
            logger.info(f"Page {one_based} is blank")
            return {"page_num": one_based, "text": "Blank page", "words": []}

        page_image = rendered[0]
        ocr_table = pytesseract.image_to_data(page_image, output_type=pytesseract.Output.DICT)
        page_text = pytesseract.image_to_string(page_image)

        word_boxes = []
        for idx, token in enumerate(ocr_table["text"]):
            if not token.strip():
                continue  # skip whitespace-only entries
            word_boxes.append({
                "text": token,
                "left": ocr_table["left"][idx],
                "top": ocr_table["top"][idx],
                "width": ocr_table["width"][idx],
                "height": ocr_table["height"][idx],
            })

        logger.info(f"Tesseract processed page {one_based} with {len(word_boxes)} words")
        return {"page_num": one_based, "text": page_text, "words": word_boxes}
    except Exception as e:
        logger.error(f"Tesseract error on page {one_based}: {e}")
        return {"page_num": one_based, "text": f"Tesseract Error: {str(e)}", "words": []}
# TrOCR OCR
def ocr_with_trocr(pdf_bytes, page_num, max_length=50):
    """OCR a single PDF page with the TrOCR model.

    Args:
        pdf_bytes: Raw bytes of the whole PDF.
        page_num: Zero-based index of the page to OCR.
        max_length: Upper bound on the number of tokens ``generate`` may
            produce for the page (default 50). Output beyond this is cut off,
            so raise it for text-dense pages.

    Returns:
        dict with "page_num" (1-based), "text", and "words". TrOCR yields no
        geometry, so every whitespace-separated word gets an all-zero box.
        When the model is unavailable or an error occurs, "text" carries a
        status/error message and "words" is empty.

    NOTE(review): the handwritten TrOCR checkpoint is presumably trained on
    single text lines, so feeding a whole rendered page likely hurts accuracy.
    """
    page_label = page_num + 1
    if trocr_model is None or trocr_processor is None:
        logger.warning(f"TrOCR not available for page {page_label}")
        return {"page_num": page_label, "text": "TrOCR not initialized", "words": []}
    try:
        images = convert_from_bytes(pdf_bytes, first_page=page_label, last_page=page_label)
        if not images:
            logger.info(f"Page {page_label} is blank")
            return {"page_num": page_label, "text": "Blank page", "words": []}
        image = images[0].convert("RGB")
        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values.to("cpu")
        with torch.no_grad():
            generated_ids = trocr_model.generate(pixel_values, max_length=max_length)
        text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        words = [{"text": word, "left": 0, "top": 0, "width": 0, "height": 0} for word in text.split()]
        logger.info(f"TrOCR processed page {page_label} with text: {text}")
        return {"page_num": page_label, "text": text, "words": words}
    except Exception as e:
        logger.error(f"TrOCR error on page {page_label}: {e}")
        return {"page_num": page_label, "text": f"TrOCR Error: {str(e)}", "words": []}
# Map Tesseract bounding boxes to OCR text
def map_tesseract_to_ocr(tesseract_result, ocr_result):
    """Attach Tesseract's word boxes to the selected OCR backend's result.

    Args:
        tesseract_result: Result of ocr_with_tesseract; "words" and
            "page_num" are read.
        ocr_result: Result of the selected backend (TrOCR or Tesseract).

    Returns:
        A shallow copy of *ocr_result* whose "words" is Tesseract's word list.

    Every Tesseract word is kept as-is; no filtering or re-alignment against
    the OCR text is performed. This avoids running one sentence-transformer
    inference per word per page, which would not change the result.
    """
    tesseract_words = tesseract_result["words"]
    if not tesseract_words or "Error" in ocr_result["text"]:
        logger.info(f"Mapping skipped for page {tesseract_result['page_num']}: No Tesseract words or OCR error")
    else:
        logger.info(f"Mapped {len(tesseract_words)} words for page {tesseract_result['page_num']}")
    return {**ocr_result, "words": tesseract_words}
# Update combined PDF
def update_combined_pdf(pdf_bytes, page_num):
    """Append page *page_num* (zero-based) of *pdf_bytes* to COMBINED_PDF_PATH.

    The existing combined PDF, if any, is re-read and rewritten with the new
    page at the end, so each call's cost grows with the file's size. The
    output is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interrupted write therefore cannot leave a
    truncated combined PDF behind.
    """
    new_page = PdfReader(io.BytesIO(pdf_bytes)).pages[page_num]
    writer = PdfWriter()
    if os.path.exists(COMBINED_PDF_PATH):
        for existing_page in PdfReader(COMBINED_PDF_PATH).pages:
            writer.add_page(existing_page)
    writer.add_page(new_page)
    tmp_path = f"{COMBINED_PDF_PATH}.tmp"
    with open(tmp_path, "wb") as f:
        writer.write(f)
    os.replace(tmp_path, COMBINED_PDF_PATH)
    logger.info(f"Updated combined PDF with page {page_num + 1}")
# Process page
def process_page(pdf_bytes, page_num, ocr_backend, filename, tracking_state, storage_mode):
    """OCR one page, store its single-page PDF, and index it in ChromaDB.

    Args:
        pdf_bytes: Raw bytes of the whole PDF.
        page_num: Zero-based page index.
        ocr_backend: "trocr" to use TrOCR for the text; any other value uses
            Tesseract's text.
        filename: Source file name, used in output names and Chroma IDs.
        tracking_state: Dict with "last_offset"; used to compute "pdf_page".
        storage_mode: "hf" uploads the page PDF to HF_DATASET_REPO; any other
            value appends it to the local combined PDF.

    Returns:
        The page result dict ("page_num", "text", "words", "page_file",
        "pdf_page").
    """
    tesseract_result = ocr_with_tesseract(pdf_bytes, page_num)
    # Tesseract has already run for its word boxes. Reuse that result rather
    # than OCR-ing the same page a second time when it is also the backend.
    ocr_result = ocr_with_trocr(pdf_bytes, page_num) if ocr_backend == "trocr" else tesseract_result
    combined_result = map_tesseract_to_ocr(tesseract_result, ocr_result)
    page_number = combined_result["page_num"]

    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    local_page_path = os.path.join(PAGES_DIR, f"{filename}_page_{page_number}_{timestamp}.pdf")
    writer = PdfWriter()
    writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[page_num])
    with open(local_page_path, "wb") as f:
        writer.write(f)

    if storage_mode == "hf":
        remote_page_path = f"pages/{os.path.basename(local_page_path)}"
        HF_API.upload_file(
            path_or_fileobj=local_page_path,
            path_in_repo=remote_page_path,
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN
        )
        logger.info(f"Uploaded page to {HF_DATASET_REPO}/{remote_page_path}")
        combined_result["page_file"] = remote_page_path
    else:  # local
        update_combined_pdf(pdf_bytes, page_num)
        combined_result["page_file"] = local_page_path
    combined_result["pdf_page"] = tracking_state["last_offset"] + page_num

    # Update ChromaDB. upsert (rather than add) makes re-processing a page
    # replace its previous text/metadata instead of keeping the stale entry.
    chroma_collection.upsert(
        documents=[combined_result["text"]],
        metadatas=[{"filename": filename, "page_num": page_number, "page_file": combined_result["page_file"], "words": json.dumps(combined_result["words"])}],
        ids=[f"{filename}_page_{page_number}"]
    )
    logger.info(f"Added page {page_number} to ChromaDB")
    return combined_result
# Extract PDF URLs from text
def extract_pdf_urls(text):
    """Return the distinct http(s) URLs ending in ".pdf" found in *text*.

    A match must be followed by a word boundary right after ".pdf", so e.g.
    ``http://host/a.pdfx`` is not truncated to ``http://host/a.pdf``.
    Duplicate URLs are dropped (first-occurrence order is kept), so each PDF
    is queued only once.
    """
    url_pattern = r'(https?://[^\s]+?\.pdf)\b'
    return list(dict.fromkeys(re.findall(url_pattern, text)))
# Load or initialize tracking state
def load_tracking_state():
    """Return the resume state stored in TRACKING_FILE.

    When the file does not exist yet, a fresh state is returned:
    ``{"processed_urls": {}, "last_offset": 0}``.
    """
    if not os.path.exists(TRACKING_FILE):
        return {"processed_urls": {}, "last_offset": 0}
    with open(TRACKING_FILE, "r") as state_file:
        return json.load(state_file)
def save_tracking_state(state):
    """Write *state* as JSON to TRACKING_FILE, creating its directory if needed."""
    # The tracking file lives in a cache directory that may not exist yet.
    state_dir = os.path.dirname(TRACKING_FILE)
    if state_dir:
        os.makedirs(state_dir, exist_ok=True)
    with open(TRACKING_FILE, "w") as f:
        json.dump(state, f)
# Push to Hugging Face Dataset
def push_to_hf_dataset(new_data):
    """Merge *new_data* records into HF_DATASET_REPO and push the result.

    Each record must carry a string "url".
    - A record whose URL is new is appended as a row. Every column gets a
      value, defaulting to None for keys the record lacks.
    - A record for an existing URL extends that row's "pages" and replaces
      its "embedding" and "processed_at".

    Raises:
        ValueError: if a record lacks a string "url".
        Exception: any load/merge/push failure is logged and re-raised.
    """
    required_keys = ["filename", "pages", "url", "embedding", "processed_at", "pdf_page_offset"]
    try:
        for item in new_data:
            if "url" not in item or not isinstance(item["url"], str):
                logger.error(f"Invalid item in new_data: {item}")
                raise ValueError(f"Each item must have a valid 'url' key; found {item}")
        try:
            dataset = load_dataset(HF_DATASET_REPO, token=HF_TOKEN, cache_dir="/app/cache")
            existing_data = dataset["train"].to_dict()
            logger.info(f"Loaded existing dataset with keys: {list(existing_data.keys())}")
        except Exception as e:
            logger.info(f"No existing dataset found or error loading: {e}, initializing new dataset")
            existing_data = {key: [] for key in required_keys}

        # Dataset.from_dict needs every column to have the same length, so a
        # column missing from the hub copy is back-filled with None for the
        # rows that already exist (not with an empty list).
        num_rows = max((len(column) for column in existing_data.values()), default=0)
        for key in required_keys:
            if key not in existing_data:
                existing_data[key] = [None] * num_rows
                logger.warning(f"Initialized missing key '{key}' in existing_data")

        # url -> first row holding it (same row list.index would find), for O(1) lookups.
        url_to_row = {}
        for row, existing_url in enumerate(existing_data["url"]):
            url_to_row.setdefault(existing_url, row)

        for item in new_data:
            logger.debug(f"Processing item: {item}")
            row = url_to_row.get(item["url"])
            if row is None:
                # Append to every column, including any extra columns present
                # on the hub, so all columns keep equal length.
                for key, column in existing_data.items():
                    column.append(item.get(key, None))
                url_to_row[item["url"]] = len(existing_data["url"]) - 1
                logger.info(f"Added new URL: {item['url']}")
            else:
                existing_data["pages"][row] = (existing_data["pages"][row] or []) + list(item["pages"])
                existing_data["embedding"][row] = item["embedding"]
                existing_data["processed_at"][row] = item["processed_at"]
                logger.info(f"Updated existing URL: {item['url']}")

        updated_dataset = Dataset.from_dict(existing_data)
        updated_dataset.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
        logger.info(f"Successfully appended/updated {len(new_data)} records to {HF_DATASET_REPO}")
    except Exception as e:
        logger.error(f"Failed to push to HF Dataset: {str(e)}")
        raise
# Check if URL is fully processed
def is_url_fully_processed(url, progress_log, total_pages):
    """Return True when *url* is marked "completed" with all *total_pages* processed."""
    url_entries = progress_log["urls"]
    if url not in url_entries:
        return False
    entry = url_entries[url]
    return entry["status"] == "completed" and entry["processed_pages"] >= total_pages
# Process PDF URL with SSE
def process_pdf_url(url, ocr_backend, tracking_state, progress_log, storage_mode):
    """Fetch one PDF URL, OCR its unprocessed pages, and stream SSE progress.

    Yields ``data: <json>\\n\\n`` strings. Statuses are "fetching", then
    "processing" plus a page payload per page, then "completed". A URL that
    is already fully processed yields "skipped"; failures yield "error".

    Side effects (all in this order):
    - *progress_log* is saved after every page.
    - In "hf" mode, the collected pages are pushed via push_to_hf_dataset.
    - *tracking_state*["last_offset"] is advanced by the number of pages
      processed in this call.
    - Both states are saved at the end.

    Exceptions are caught and reported as "error" events, not raised.
    """
    filename = url.split("/")[-1]
    try:
        yield f"data: {json.dumps({'status': 'fetching', 'filename': filename})}\n\n"
        logger.info(f"Fetching PDF from {url}")
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        pdf_bytes = response.content
        pdf_reader = PdfReader(io.BytesIO(pdf_bytes))
        total_pages = len(pdf_reader.pages)
        # Resume support: start from the first page not yet recorded as processed.
        progress_log["urls"].setdefault(url, {"status": "pending", "processed_pages": 0})
        start_page = progress_log["urls"][url]["processed_pages"]
        if is_url_fully_processed(url, progress_log, total_pages):
            yield f"data: {json.dumps({'status': 'skipped', 'filename': filename, 'message': 'URL already fully processed'})}\n\n"
            return
        # NOTE(review): `pages` holds only the pages processed in this call. On
        # a resumed run, the embedding below covers just those pages, and
        # push_to_hf_dataset replaces the stored embedding for an existing URL
        # with it. Pages from an interrupted earlier run are never pushed to
        # the HF dataset.
        pages = []
        for page_num in range(start_page, total_pages):
            yield f"data: {json.dumps({'status': 'processing', 'filename': filename, 'page_num': page_num + 1, 'total_pages': total_pages})}\n\n"
            page = process_page(pdf_bytes, page_num, ocr_backend, filename, tracking_state, storage_mode)
            pages.append(page)
            yield f"data: {json.dumps({'filename': filename, 'page': page})}\n\n"
            # Checkpoint after each page so an interruption can resume here.
            progress_log["urls"][url]["processed_pages"] = page_num + 1
            save_progress_log(progress_log, storage_mode)
        full_text = "\n\n".join(f"Page {page['page_num']}\n{page['text']}" for page in pages)
        # Whitespace-only text gets no embedding (stored as None).
        embedding = embedder.encode(full_text).tolist() if full_text.strip() else None
        result = {
            "filename": filename,
            "pages": pages,
            "url": url,
            "embedding": embedding,
            "processed_at": datetime.now().isoformat(),
            "pdf_page_offset": tracking_state["last_offset"]
        }
        if storage_mode == "hf":
            push_to_hf_dataset([result])
        tracking_state["last_offset"] += total_pages - start_page
        progress_log["urls"][url]["status"] = "completed"
        save_tracking_state(tracking_state)
        save_progress_log(progress_log, storage_mode)
        yield f"data: {json.dumps({'status': 'completed', 'filename': filename, 'new_offset': tracking_state['last_offset']})}\n\n"
        logger.info(f"Completed processing {filename} with new offset {tracking_state['last_offset']}")
    except requests.RequestException as e:
        logger.error(f"Failed to fetch PDF from {url}: {e}")
        yield f"data: {json.dumps({'status': 'error', 'filename': filename, 'message': f'Error fetching PDF: {str(e)}'})}\n\n"
    except Exception as e:
        logger.error(f"Error processing {url}: {e}")
        yield f"data: {json.dumps({'status': 'error', 'filename': filename, 'message': f'Error: {str(e)}'})}\n\n"
# Process text content with SSE
def process_text_content(text, filename, ocr_backend, tracking_state, progress_log, storage_mode):
    """Stream SSE events for every PDF URL mentioned in *text*.

    First yields an "info" event that summarizes which URLs are already
    "completed" in *progress_log* and which will be processed. Then it
    delegates each remaining URL to process_pdf_url. Any exception is
    reported as an "error" event.
    """
    try:
        pdf_urls = extract_pdf_urls(text)

        def _already_done(link):
            return link in progress_log["urls"] and progress_log["urls"][link]["status"] == "completed"

        processed_urls = [link for link in pdf_urls if _already_done(link)]
        new_urls = [link for link in pdf_urls if not _already_done(link)]

        done_listing = "\n".join(processed_urls)
        todo_listing = "\n".join(new_urls)
        initial_text = (
            f"Found {len(pdf_urls)} PDF URLs:\n"
            f"Already processed: {len(processed_urls)}\n"
            f"{done_listing}\n"
            f"To process: {len(new_urls)}\n"
            f"{todo_listing}\n\nProcessing..."
        )
        yield f"data: {json.dumps({'status': 'info', 'filename': filename, 'message': initial_text})}\n\n"

        for link in new_urls:
            logger.info(f"Starting processing of {link} with offset {tracking_state['last_offset']}")
            yield from process_pdf_url(link, ocr_backend, tracking_state, progress_log, storage_mode)
    except Exception as e:
        logger.error(f"Error processing text content for {filename}: {e}")
        yield f"data: {json.dumps({'status': 'error', 'filename': filename, 'message': f'Error: {str(e)}'})}\n\n"
# Home route
@app.route("/", methods=["GET"])
def index():
    """Render the landing page (templates/index.html)."""
    landing_template = "index.html"
    return render_template(landing_template)
# Process URL endpoint with GET
@app.route("/process_url", methods=["GET"])
def process_url():
    """Stream OCR progress for a URL as Server-Sent Events.

    Query parameters:
        url: A ``.pdf`` URL (processed directly) or a ``.txt`` URL whose body
            is scanned for PDF links. Required; if missing, a 400 JSON error
            is returned.
        ocr_backend: "trocr" (default); any other value uses Tesseract text.
        storage_mode: "hf" (default); any other value stores locally.
    """
    params = request.args
    url = params.get("url")
    ocr_backend = params.get("ocr_backend", "trocr")
    storage_mode = params.get("storage_mode", "hf")
    if not url:
        return jsonify({"error": "No URL provided"}), 400

    # Loaded once per request; the event stream below mutates and saves them.
    tracking_state = load_tracking_state()
    progress_log = load_progress_log(storage_mode)

    def _events_for_text_listing():
        # Fetch a plain-text listing and process every PDF URL found in it.
        try:
            listing_response = requests.get(url, timeout=10)
            listing_response.raise_for_status()
            listing_text = listing_response.text
            listing_name = url.split("/")[-1]
            logger.info(f"Fetched text from {url}")
            yield from process_text_content(listing_text, listing_name, ocr_backend, tracking_state, progress_log, storage_mode)
        except requests.RequestException as e:
            logger.error(f"Failed to fetch text from {url}: {e}")
            yield f"data: {json.dumps({'status': 'error', 'filename': url, 'message': f'Error fetching URL: {str(e)}'})}\n\n"

    def generate():
        logger.info(f"Processing URL: {url} with ocr_backend={ocr_backend}, storage_mode={storage_mode}, starting offset={tracking_state['last_offset']}")
        if url.endswith(".pdf"):
            yield from process_pdf_url(url, ocr_backend, tracking_state, progress_log, storage_mode)
        elif url.endswith(".txt"):
            yield from _events_for_text_listing()
        else:
            yield f"data: {json.dumps({'status': 'error', 'filename': url, 'message': 'Unsupported URL format. Must end in .pdf or .txt'})}\n\n"
        logger.info(f"Finished processing URL: {url}")

    return Response(generate(), mimetype="text/event-stream")
# Search page
@app.route("/search", methods=["GET"])
def search_page():
    """Render search.html with every stored document.

    "hf" mode (default) lists one entry per dataset row. Any other mode lists
    one entry per page stored in ChromaDB. In "hf" mode, a loading failure
    renders the page with an empty list and the error message.
    """
    storage_mode = request.args.get("storage_mode", "hf")

    if storage_mode != "hf":
        # local: one file entry per ChromaDB page record
        stored = chroma_collection.get()
        files = [
            {
                "filename": meta["filename"],
                "url": "",
                "pages": [{
                    "page_num": meta["page_num"],
                    "text": stored["documents"][position],
                    "page_file": meta["page_file"],
                    "words": json.loads(meta["words"]),
                }],
            }
            for position, meta in enumerate(stored["metadatas"])
        ]
        return render_template("search.html", files=files, storage_mode=storage_mode)

    try:
        train_split = load_dataset(HF_DATASET_REPO, token=HF_TOKEN, cache_dir="/app/cache")["train"]
        files = [
            {"filename": name, "url": link, "pages": page_list}
            for name, link, page_list in zip(train_split["filename"], train_split["url"], train_split["pages"])
        ]
        return render_template("search.html", files=files, storage_mode=storage_mode)
    except Exception as e:
        logger.error(f"Error loading search page: {e}")
        return render_template("search.html", files=[], error=str(e), storage_mode=storage_mode)
def _cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two vectors, or 0.0 if either has zero norm."""
    denominator = norm(vec_a) * norm(vec_b)
    if denominator == 0:
        return 0.0
    return float(dot(vec_a, vec_b) / denominator)


def _sentence_highlights(query_embedding, text, words, threshold=0.7):
    """Return the sentences of *text* scoring above *threshold*, with matched word boxes.

    Each highlight is ``{"sentence", "index", "words"}``. "words" collects, in
    order, OCR word entries whose text appears in successive tokens of the
    sentence.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    # One batched encode call instead of one model call per sentence.
    sentence_embeddings = embedder.encode(sentences)
    highlights = []
    for sent_idx, (sentence, sent_embedding) in enumerate(zip(sentences, sentence_embeddings)):
        if _cosine_similarity(query_embedding, sent_embedding) > threshold:
            matching_words = []
            sent_words = sentence.split()
            word_idx = 0
            for word in words:
                if word_idx < len(sent_words) and word["text"].lower() in sent_words[word_idx].lower():
                    matching_words.append(word)
                    word_idx += 1
            highlights.append({"sentence": sentence, "index": sent_idx, "words": matching_words})
    return highlights


def _fetch_page_pdf(page_url):
    """Download a page PDF from the HF dataset repo and return it base64-encoded."""
    # Send the token too so pages in a private dataset can be fetched.
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    response = requests.get(page_url, headers=headers, timeout=30)
    response.raise_for_status()
    return base64.b64encode(response.content).decode('utf-8')


def _search_hf_dataset(query, top_k=5):
    """Rank HF dataset rows by document-embedding similarity to *query*."""
    dataset = load_dataset(HF_DATASET_REPO, token=HF_TOKEN, cache_dir="/app/cache")["train"]
    query_embedding = embedder.encode(query).tolist()
    # Keep each row's own index next to its score. Rows without an embedding
    # are skipped without shifting the indices used for the other columns.
    scored_rows = [
        (row_idx, _cosine_similarity(query_embedding, emb))
        for row_idx, emb in enumerate(dataset["embedding"])
        if emb is not None
    ]
    scored_rows.sort(key=lambda pair: pair[1], reverse=True)

    documents = dataset["pages"]
    filenames = dataset["filename"]
    urls = dataset["url"]
    processed_ats = dataset["processed_at"]
    pdf_page_offsets = dataset["pdf_page_offset"]

    results = []
    for row_idx, similarity in scored_rows[:top_k]:
        highlighted_pages = []
        for page in documents[row_idx]:
            text = page["text"]
            page_url = f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{page['page_file']}"
            highlighted_pages.append({
                "page_num": page["page_num"],
                "text": text,
                "highlights": _sentence_highlights(query_embedding, text, page["words"]),
                "pdf_page": page["pdf_page"],
                "pdf_data": _fetch_page_pdf(page_url),
                "page_url": page_url
            })
        results.append({
            "filename": filenames[row_idx],
            "pages": highlighted_pages,
            "url": urls[row_idx],
            "processed_at": processed_ats[row_idx],
            "similarity": similarity,
            "pdf_page_offset": pdf_page_offsets[row_idx]
        })
    return results


def _search_chroma(query, top_k=5):
    """Return the top-*top_k* ChromaDB page hits for *query*."""
    query_results = chroma_collection.query(query_texts=[query], n_results=top_k)
    # Embed the query once, not once per hit.
    query_embedding = embedder.encode(query).tolist()
    hits = zip(query_results["documents"][0], query_results["metadatas"][0], query_results["distances"][0])
    results = []
    for text, metadata, distance in hits:
        words = json.loads(metadata["words"])
        page_file = metadata["page_file"]
        # Pages indexed in "hf" storage mode store a repo-relative path, not a
        # local file. Report an empty PDF instead of failing the whole search.
        if os.path.exists(page_file):
            with open(page_file, "rb") as f:
                pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
        else:
            logger.warning(f"Page file not found locally: {page_file}")
            pdf_base64 = ""
        results.append({
            "filename": metadata["filename"],
            "pages": [{
                "page_num": metadata["page_num"],
                "text": text,
                "highlights": _sentence_highlights(query_embedding, text, words),
                "pdf_page": metadata["page_num"],
                "pdf_data": pdf_base64,
                "page_url": page_file
            }],
            "url": "",
            "processed_at": datetime.now().isoformat(),
            # NOTE(review): this is Chroma's query distance, not a cosine
            # similarity like the "hf" branch reports (presumably smaller means closer).
            "similarity": distance
        })
    return results


# Semantic search route
@app.route("/search_documents", methods=["POST"])
def search_documents():
    """Semantic search over stored documents; returns ``{"results": [...]}`` JSON.

    Form fields: "query" (required; 400 if missing) and "storage_mode"
    ("hf" by default, otherwise the local ChromaDB collection). Any search
    failure is logged and returned as a 500 JSON error.
    """
    query = request.form.get("query")
    storage_mode = request.form.get("storage_mode", "hf")
    if not query:
        return jsonify({"error": "No query provided"}), 400
    if storage_mode == "hf":
        try:
            return jsonify({"results": _search_hf_dataset(query)})
        except Exception as e:
            logger.error(f"Search error: {e}")
            return jsonify({"error": str(e)}), 500
    # local with ChromaDB
    try:
        return jsonify({"results": _search_chroma(query)})
    except Exception as e:
        logger.error(f"ChromaDB search error: {e}")
        return jsonify({"error": str(e)}), 500
# Download output folder
@app.route("/download_output", methods=["GET"])
def download_output():
    """Zip OUTPUT_DIR into /app/output.zip and send it as an attachment.

    Returns a 500 JSON error if archiving or sending fails.
    """
    archive_file = "/app/output.zip"
    try:
        # make_archive appends ".zip" to the base name, producing archive_file.
        shutil.make_archive("/app/output", "zip", OUTPUT_DIR)
        return send_file(archive_file, download_name="output.zip", as_attachment=True, mimetype="application/zip")
    except Exception as e:
        logger.error(f"Error creating zip: {e}")
        return jsonify({"error": str(e)}), 500
# Preview output contents
@app.route("/preview_output", methods=["GET"])
def preview_output():
    """Return the combined PDF (base64) and the local progress log as JSON.

    Each field stays empty ("" / {}) when its file does not exist. Failures
    are returned as a 500 JSON error.
    """
    try:
        payload = {"combined_pdf": "", "progress_json": {}}
        if os.path.exists(COMBINED_PDF_PATH):
            with open(COMBINED_PDF_PATH, "rb") as pdf_file:
                payload["combined_pdf"] = base64.b64encode(pdf_file.read()).decode('utf-8')
        if os.path.exists(PROGRESS_JSON_PATH):
            with open(PROGRESS_JSON_PATH, "r") as log_file:
                payload["progress_json"] = json.load(log_file)
        return jsonify(payload)
    except Exception as e:
        logger.error(f"Error previewing output: {e}")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    # Werkzeug's interactive debugger allows arbitrary code execution, so it
    # must not be enabled by default on a server bound to 0.0.0.0. Opt in for
    # local development with FLASK_DEBUG=1 (or true/yes/on).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=port, debug=debug_enabled)