|
import base64
import json
import os
import re
import uuid

import gradio as gr
import requests

# Configuration. Read the API key from the environment (set it as a Secret
# named OPENROUTER_API_KEY in your Hugging Face Space); never hard-code
# credentials in source.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
IMAGE_MODEL = "opengvlab/internvl3-14b:free"
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# Global in-memory state for the current session.
processed_files_data = []  # one record per uploaded document
person_profiles = {}       # person_key -> aggregated person profile

|
def extract_json_from_text(text):
    """
    Extract a JSON object from model output, trying a fenced ```json block
    first, then the bare text, then the outermost {...} substring.
    """
    if not text:
        return {"error": "Empty text provided for JSON extraction."}

    # Prefer an explicit ```json ... ``` fenced block if the model emitted one.
    match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    if match_block:
        json_str = match_block.group(1)
    else:
        # Otherwise fall back to the raw text, stripping a single pair of
        # surrounding backticks if present.
        text_stripped = text.strip()
        if text_stripped.startswith("`") and text_stripped.endswith("`"):
            json_str = text_stripped[1:-1]
        else:
            json_str = text_stripped

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        # Last resort: parse the substring between the first '{' and the last '}'.
        try:
            first_brace = json_str.find('{')
            last_brace = json_str.rfind('}')
            if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
                potential_json_str = json_str[first_brace:last_brace + 1]
                return json.loads(potential_json_str)
            else:
                return {"error": f"Invalid JSON structure: {str(e)}", "original_text": text}
        except json.JSONDecodeError as e2:
            return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}
|
|
|
|
|
def get_ocr_prompt():
    # Plain (non-f) string: the prompt contains no placeholders, and literal
    # braces in JSON examples would break an f-string.
    return """You are an advanced OCR and information extraction AI.
Your task is to meticulously analyze this image and extract all relevant information.

Output Format Instructions:
Provide your response as a SINGLE, VALID JSON OBJECT. Do not include any explanatory text before or after the JSON.
The JSON object should have the following top-level keys:
- "document_type_detected": (string) Your best guess of the specific document type (e.g., "Passport", "National ID Card", "Driver's License", "Visa Sticker", "Hotel Confirmation Voucher", "Bank Statement", "Photo of a person").
- "extracted_fields": (object) A key-value map of all extracted information. Be comprehensive. Examples:
    - For passports/IDs: "Surname", "Given Names", "Full Name", "Document Number", "Nationality", "Date of Birth", "Sex", "Place of Birth", "Date of Issue", "Date of Expiry", "Issuing Authority", "Country Code".
    - For hotel reservations: "Guest Name", "Hotel Name", "Booking Reference", "Check-in Date", "Check-out Date".
    - For bank statements: "Account Holder Name", "Account Number", "Bank Name", "Statement Period", "Ending Balance".
    - For photos: "Description" (e.g., "Portrait of a person", "Group photo at a location"), "People Present" (array of strings if multiple).
- "mrz_data": (object or null) If a Machine Readable Zone (MRZ) is present:
    - "raw_mrz_lines": (array of strings) Each line of the MRZ.
    - "parsed_mrz": (object) Key-value pairs of parsed MRZ fields.
  If no MRZ is present, this field should be null.
- "full_text_ocr": (string) Concatenation of all text found on the document.

Extraction Guidelines:
1. Prioritize accuracy.
2. Extract all visible text. Include "Full Name" by combining given names and surname where possible.
3. For dates, prefer ISO 8601 format (YYYY-MM-DD), but retain the original format if conversion is ambiguous.

Ensure the entire output strictly adheres to the JSON format.
"""
|
|
|
def call_openrouter_ocr(image_filepath):
    if not OPENROUTER_API_KEY:
        return {"error": "OpenRouter API Key not configured."}
    try:
        with open(image_filepath, "rb") as f:
            encoded_image = base64.b64encode(f.read()).decode("utf-8")

        # Guess the MIME type from the file extension; default to JPEG.
        mime_type = "image/jpeg"
        if image_filepath.lower().endswith(".png"):
            mime_type = "image/png"
        elif image_filepath.lower().endswith(".webp"):
            mime_type = "image/webp"

        data_url = f"data:{mime_type};base64,{encoded_image}"
        prompt_text = get_ocr_prompt()

        payload = {
            "model": IMAGE_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {"type": "image_url", "image_url": {"url": data_url}}
                    ]
                }
            ],
            "max_tokens": 3500,
            "temperature": 0.1,
        }
        headers = {
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            # Optional headers OpenRouter uses for attribution.
            "HTTP-Referer": "https://huggingface.co/spaces/DoClassifier",
            "X-Title": "DoClassifier Processor"
        }

        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)
        response.raise_for_status()
        result = response.json()

        if "choices" in result and result["choices"]:
            raw_content = result["choices"][0]["message"]["content"]
            return extract_json_from_text(raw_content)
        else:
            return {"error": "No 'choices' in API response from OpenRouter.", "details": result}

    except requests.exceptions.Timeout:
        return {"error": "API request timed out."}
    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_message += f" Status: {e.response.status_code}, Response: {e.response.text}"
        return {"error": error_message}
    except Exception as e:
        return {"error": f"An unexpected error occurred during OCR: {str(e)}"}
|
|
|
def _find_field(fields, candidate_keys):
    """Case-insensitive lookup: return the first non-empty value whose key
    matches one of candidate_keys, else None."""
    for key in candidate_keys:
        for field_key, value in fields.items():
            if key == field_key.lower() and value:
                return str(value)
    return None


def extract_entities_from_ocr(ocr_json):
    if not ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
        doc_type = ocr_json.get("document_type_detected", "Unknown") if ocr_json else "Unknown"
        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type}

    fields = ocr_json["extracted_fields"]
    doc_type = ocr_json.get("document_type_detected", "Unknown")

    # Candidate field names in priority order, matched case-insensitively.
    name_keys = ["full name", "name", "account holder name", "guest name"]
    dob_keys = ["date of birth", "dob"]
    passport_keys = ["document number", "passport number"]

    extracted_name = _find_field(fields, name_keys)
    extracted_dob = _find_field(fields, dob_keys)
    extracted_passport_no = _find_field(fields, passport_keys)
    if extracted_passport_no:
        # Normalize document numbers for matching: strip spaces, uppercase.
        extracted_passport_no = extracted_passport_no.replace(" ", "").upper()

    return {
        "name": extracted_name,
        "dob": extracted_dob,
        "passport_no": extracted_passport_no,
        "doc_type": doc_type
    }
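
# Illustrative example (values invented):
#   ocr = {"document_type_detected": "Passport",
#          "extracted_fields": {"Full Name": "Jane Doe",
#                               "Date of Birth": "1990-01-01",
#                               "Document Number": "x 1234567"}}
#   extract_entities_from_ocr(ocr)
#   -> {"name": "Jane Doe", "dob": "1990-01-01",
#       "passport_no": "X1234567", "doc_type": "Passport"}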
|
|
|
def normalize_name(name):
    """Lowercase and strip non-alphanumeric characters for fuzzy matching."""
    if not name:
        return ""
    return "".join(filter(str.isalnum, name)).lower()
|
|
|
def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
    """
    Assign a document to an existing person or create a new profile.
    Matching is attempted in order of reliability: passport number, then
    normalized name + DOB, then name alone. Updates current_persons_data
    in place and returns the person_key.
    """
    passport_no = entities.get("passport_no")
    name = entities.get("name")
    dob = entities.get("dob")

    # 1. Passport number is the strongest identifier.
    if passport_no:
        for p_key, p_data in current_persons_data.items():
            if passport_no in p_data.get("passport_numbers", set()):
                p_data["doc_ids"].add(doc_id)
                # Merge any newly learned details into the existing profile so
                # later name/DOB-only documents can still match it.
                if name and not p_data.get("canonical_name"):
                    p_data["canonical_name"] = name
                if dob and not p_data.get("canonical_dob"):
                    p_data["canonical_dob"] = dob
                if name:
                    p_data["names"].add(normalize_name(name))
                if dob:
                    p_data["dobs"].add(dob)
                return p_key

        new_person_key = f"person_{passport_no}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {normalize_name(name)} if name else set(),
            "dobs": {dob} if dob else set(),
            "passport_numbers": {passport_no},
            "doc_ids": {doc_id},
            "display_name": name or f"Person (ID: {passport_no})"
        }
        return new_person_key

    # 2. Fall back to normalized name + date of birth.
    if name and dob:
        norm_name = normalize_name(name)
        composite_key_nd = f"{norm_name}_{dob}"
        for p_key, p_data in current_persons_data.items():
            if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                p_data["doc_ids"].add(doc_id)
                return p_key

        new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {norm_name},
            "dobs": {dob},
            "passport_numbers": set(),
            "doc_ids": {doc_id},
            "display_name": name
        }
        return new_person_key

    # 3. Name alone is too weak to merge on; create a new profile rather than
    # guess a match.
    if name:
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
            "names": {norm_name}, "dobs": set(), "passport_numbers": set(),
            "doc_ids": {doc_id}, "display_name": name
        }
        return new_person_key

    # 4. Nothing usable was extracted: file under an unidentified profile.
    generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
    current_persons_data[generic_person_key] = {
        "canonical_name": "Unknown", "canonical_dob": None,
        "names": set(), "dobs": set(), "passport_numbers": set(),
        "doc_ids": {doc_id}, "display_name": f"Unknown Person ({doc_id[:6]})"
    }
    return generic_person_key
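
# Illustrative flow (values invented): a passport and a hotel voucher for the
# same traveller merge via the name+DOB fallback, because the passport
# document registered the normalized name and DOB on the profile first.
#   profiles = {}
#   k1 = get_person_id_and_update_profiles(
#       "doc-1", {"name": "Jane Doe", "dob": "1990-01-01", "passport_no": "X1234567"}, profiles)
#   k2 = get_person_id_and_update_profiles(
#       "doc-2", {"name": "Jane Doe", "dob": "1990-01-01", "passport_no": None}, profiles)
#   # k2 == k1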
|
|
|
|
|
def format_dataframe_data(current_files_data):
    # Row layout must match dataframe_headers defined in the UI below.
    df_rows = []
    for f_data in current_files_data:
        entities = f_data.get("entities") or {}
        df_rows.append([
            f_data["doc_id"][:8],
            f_data["filename"],
            f_data["status"],
            # Use "or" so None values render as "N/A" in the table.
            entities.get("doc_type") or "N/A",
            entities.get("name") or "N/A",
            entities.get("dob") or "N/A",
            entities.get("passport_no") or "N/A",
            f_data.get("assigned_person_key") or "N/A"
        ])
    return df_rows
|
|
|
def format_persons_markdown(current_persons_data, current_files_data):
    if not current_persons_data:
        return "No persons identified yet."

    md_parts = ["## Classified Persons & Documents\n"]
    for p_key, p_data in current_persons_data.items():
        display_name = p_data.get('display_name', p_key)
        md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
        if p_data.get("canonical_dob"):
            md_parts.append(f"* DOB: {p_data['canonical_dob']}")
        if p_data.get("passport_numbers"):
            # Sort for deterministic output; sets have no stable order.
            md_parts.append(f"* Passport(s): {', '.join(sorted(p_data['passport_numbers']))}")

        md_parts.append("* Documents:")
        doc_ids_for_person = p_data.get("doc_ids", set())
        if doc_ids_for_person:
            for doc_id in doc_ids_for_person:
                # Look up the file record for this document ID.
                doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                if doc_detail:
                    filename = doc_detail["filename"]
                    doc_type = (doc_detail.get("entities") or {}).get("doc_type", "Unknown Type")
                    md_parts.append(f"    - {filename} (`{doc_type}`)")
                else:
                    md_parts.append(f"    - Document ID: {doc_id[:8]} (details not found)")
        else:
            md_parts.append("    - No documents currently assigned.")
        md_parts.append("\n---\n")
    return "\n".join(md_parts)
|
|
|
|
|
def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
    global processed_files_data, person_profiles
    processed_files_data = []
    person_profiles = {}

    if not OPENROUTER_API_KEY:
        yield (
            [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
            "Error: OpenRouter API Key not configured. Please set it in Space Secrets.",
            "{}", "API Key Missing. Processing halted."
        )
        return

    if not files_list:
        yield ([], "No files uploaded.", "{}", "Upload files to begin.")
        return

    # Register every upload first so the table shows the full queue at once.
    for file_obj in files_list:
        # gr.Files may yield plain path strings or tempfile-like objects,
        # depending on the Gradio version; handle both.
        filepath = file_obj if isinstance(file_obj, str) else file_obj.name
        processed_files_data.append({
            "doc_id": str(uuid.uuid4()),
            "filename": os.path.basename(filepath),
            "filepath": filepath,
            "status": "Queued",
            "ocr_json": None,
            "entities": None,
            "assigned_person_key": None
        })

    initial_df_data = format_dataframe_data(processed_files_data)
    initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")

    # Process each document: OCR -> entity extraction -> person classification.
    for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
        current_doc_id = file_data_item["doc_id"]
        current_filename = file_data_item["filename"]

        file_data_item["status"] = "OCR in Progress..."
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")

        ocr_result = call_openrouter_ocr(file_data_item["filepath"])
        file_data_item["ocr_json"] = ocr_result

        if "error" in ocr_result:
            file_data_item["status"] = f"OCR Error: {ocr_result['error'][:50]}..."
            df_data = format_dataframe_data(processed_files_data)
            yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
            continue

        file_data_item["status"] = "OCR Done. Extracting Entities..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")

        entities = extract_entities_from_ocr(ocr_result)
        file_data_item["entities"] = entities
        file_data_item["status"] = "Entities Extracted. Classifying..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")

        person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
        file_data_item["assigned_person_key"] = person_key
        file_data_item["status"] = "Classified"

        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")

    final_df_data = format_dataframe_data(processed_files_data)
    final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")
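
# Each yield updates (dataframe rows, persons markdown, OCR JSON, status text),
# in that order, matching the outputs wired to process_button.click below.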
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
    gr.Markdown(
        "**Upload multiple documents (images of passports, bank statements, hotel reservations, photos, etc.). "
        "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
        "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
    )

    if not OPENROUTER_API_KEY:
        gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")

    with gr.Row():
        with gr.Column(scale=1):
            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
            process_button = gr.Button("Process Uploaded Documents", variant="primary")
            overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)

    gr.Markdown("---")
    gr.Markdown("## Document Processing Details")

    dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
    document_status_df = gr.Dataframe(
        headers=dataframe_headers,
        datatype=["str"] * len(dataframe_headers),
        label="Individual Document Status & Extracted Entities",
        row_count=(0, "dynamic"),
        col_count=(len(dataframe_headers), "fixed"),
        wrap=True
    )

    ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)

    gr.Markdown("---")
    person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")

    process_button.click(
        fn=process_uploaded_files,
        inputs=[files_input],
        outputs=[
            document_status_df,
            person_classification_output_md,
            ocr_json_output,
            overall_status_textbox
        ]
    )
|
|
|
@document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden") |
|
def display_selected_ocr(evt: gr.SelectData): |
|
if evt.index is None or evt.index[0] is None: |
|
return "{}" |
|
|
|
selected_row_index = evt.index[0] |
|
if selected_row_index < len(processed_files_data): |
|
selected_doc_data = processed_files_data[selected_row_index] |
|
if selected_doc_data and selected_doc_data["ocr_json"]: |
|
return json.dumps(selected_doc_data["ocr_json"], indent=2) |
|
return "{ \"message\": \"No OCR data found for selected row or selection out of bounds.\" }" |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.queue().launch(debug=True, share=True) |