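"""Patent Text Encoder Gradio app.

Reads a CSV with 'Master Patent Number', 'Abstract', and 'Claims' columns,
preprocesses the text, drops terms shared by more than 10% of the patents,
encodes each patent with the 'anferico/bert-for-patents' SentenceTransformer,
and returns the embeddings (HDF5), per-patent metadata (JSONL), and the list
of removed words (TXT) as downloadable files.
"""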
import json
import logging
import os
import re
import tempfile
import time

import gradio as gr
import h5py
import nltk
import numpy as np
import pandas as pd
import torch
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Download the NLTK resources used by word_tokenize and the stopword filter.
# 'punkt_tab' is the tokenizer data used by newer NLTK releases; older ones use 'punkt'.
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Silence the Hugging Face tokenizers parallelism warning when the app forks workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Use a GPU when available; model.encode() will run on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.info("Loading SentenceTransformer model...")
model = SentenceTransformer('anferico/bert-for-patents').to(device)
logging.info("SentenceTransformer model loaded successfully.")

def preprocess_text(text):
    """Normalize patent text: strip markers, selectively lowercase, remove stopwords, join numbers to units."""
    # Remove '[EN]' language markers and leading claim/list numbers such as '1. '.
    text = re.sub(r'\[EN\]\s*', '', text)
    text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)

    # Lowercase words, but keep acronyms and number+unit tokens (e.g. '5mm') as written.
    words = text.split()
    text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)

    # Drop punctuation, keeping hyphens and decimal points inside numbers.
    text = re.sub(r'[^\w\s\-.]', ' ', text)
    text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text)

    # Collapse whitespace.
    text = re.sub(r'\s+', ' ', text).strip()

    # Tokenize and remove English stopwords.
    tokens = word_tokenize(text)
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word.lower() not in stop_words]
    text = ' '.join(tokens)

    # Insert an underscore between a number and a unit glued to it (e.g. '5mm' -> '5_mm').
    text = re.sub(r'(\d+(\.\d+)?)([a-zA-Z]+)', r'\1_\3', text)

    # Normalize numeric ranges attached to units ('X to Y unit', 'between X and Y unit').
    text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
    text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\6', text)

    # Intended to keep chemical-formula-like tokens (e.g. SiO2) as single tokens.
    text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)

    return text

def filter_common_terms(texts, threshold=0.10):
    """Remove words that occur in more than `threshold` of the documents."""
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(texts)
    # Document frequency: in how many texts each term appears at least once.
    document_frequencies = np.sum(X.toarray() > 0, axis=0)
    num_documents = X.shape[0]

    common_terms = set()
    removed_words = {}
    for term, doc_freq in zip(vectorizer.get_feature_names_out(), document_frequencies):
        if doc_freq / num_documents > threshold:
            common_terms.add(term)
            removed_words[term] = doc_freq

    # Strip the common terms from every text.
    filtered_texts = []
    for text in texts:
        filtered_text = ' '.join([word for word in text.split() if word not in common_terms])
        filtered_texts.append(filtered_text)

    return filtered_texts, removed_words
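
# With the default threshold of 0.10, any word that occurs in more than 10% of the
# uploaded patents is dropped before encoding; e.g. in a 1,000-row file, a word
# present in 101 or more documents is removed and logged to removed_words.txt.
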
def encode_texts(texts, progress=gr.Progress(), batch_size=64):
    """Encode texts in batches and return L2-normalized embeddings."""
    embeddings = []
    total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)

    for i in range(0, len(texts), batch_size):
        batch_texts = [str(text) for text in texts[i:i + batch_size]]
        batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
        embeddings.extend(batch_embeddings)
        progress((i // batch_size + 1) / total_batches, f"Processing batch {i // batch_size + 1}/{total_batches}")

    # Normalize each embedding to unit length.
    embeddings = np.array(embeddings)
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings
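
# Note: because encode_texts() returns unit-length vectors, a downstream similarity
# search (not part of this app) could rank patents with a plain dot product,
# for example: scores = embeddings @ query_embedding
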
def process_file(file, progress=gr.Progress()):
    """Read the uploaded CSV, encode every patent, and write the downloadable outputs."""
    try:
        start_time = time.time()

        # Gradio exposes the uploaded file's temporary path via `.name`.
        df = pd.read_csv(file.name, encoding='utf-8')
        logging.info(f"CSV file read successfully. Shape: {df.shape}")

        required_columns = ['Master Patent Number', 'Abstract', 'Claims']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return None, None, None, f"Error: Missing columns: {', '.join(missing_columns)}"

        valid_texts = []
        valid_patent_numbers = []
        skipped_rows = []
        error_rows = []
        total_rows = len(df)

        for index, row in df.iterrows():
            try:
                progress((index + 1) / total_rows, f"Processing row {index + 1}/{total_rows}")
                logging.info(f"Processing row {index + 1}/{total_rows}")

                abstract = row['Abstract'] if pd.notna(row['Abstract']) else ''
                claims = row['Claims'] if pd.notna(row['Claims']) else ''

                # Skip patents that have neither an abstract nor claims.
                if not abstract and not claims:
                    skipped_rows.append(row['Master Patent Number'])
                    continue

                preprocessed_abstract = preprocess_text(abstract)
                preprocessed_claims = preprocess_text(claims)

                # One document per patent: abstract followed by claims.
                combined_text = preprocessed_abstract + ' ' + preprocessed_claims

                valid_texts.append(combined_text)
                valid_patent_numbers.append(str(row['Master Patent Number']))

            except Exception as e:
                error_message = f"Error processing row {index + 1}: {str(e)}"
                logging.error(error_message)
                error_rows.append((index, row['Master Patent Number'], error_message))
                continue

        logging.info(f"Preprocessed abstracts and claims. Number of valid texts: {len(valid_texts)}")

        if skipped_rows:
            logging.info(f"Skipped {len(skipped_rows)} rows due to missing abstract and claims.")
        if error_rows:
            logging.info(f"Encountered errors in {len(error_rows)} rows.")

        # Remove words shared by more than 10% of the patents before encoding.
        logging.info("Filtering common terms...")
        filtered_texts, removed_words = filter_common_terms(valid_texts, threshold=0.10)

        # Save the removed high-frequency words and their document counts for review.
        removed_words_file = 'removed_words.txt'
        with open(removed_words_file, 'w', encoding='utf-8') as f:
            for word, count in sorted(removed_words.items(), key=lambda x: x[1], reverse=True):
                f.write(f"{word}: {count}\n")

        logging.info("Encoding texts...")
        embeddings = encode_texts(filtered_texts, progress)
        logging.info("Texts encoded successfully.")

        # Write embeddings and patent numbers to a temporary HDF5 file.
        embeddings_file = tempfile.NamedTemporaryFile(delete=False, suffix='.h5').name
        with h5py.File(embeddings_file, 'w') as f:
            f.create_dataset('embeddings', data=embeddings)
            # Store patent numbers explicitly as variable-length UTF-8 strings.
            f.create_dataset('patent_numbers', data=valid_patent_numbers,
                             dtype=h5py.string_dtype(encoding='utf-8'))

        # Write one JSON record per patent, aligned with the embedding rows.
        metadata_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl').name
        with open(metadata_file, 'w', encoding='utf-8') as f:
            for index, (patent_number, text) in enumerate(zip(valid_patent_numbers, filtered_texts)):
                json.dump({
                    'index': index,
                    'patent_number': patent_number,
                    'text': text,
                    'embedding_index': index
                }, f, ensure_ascii=False)
                f.write('\n')

        end_time = time.time()
        total_time = end_time - start_time
        logging.info(f"Processing completed in {total_time:.2f} seconds.")

        # Keep a log of any rows that raised errors during preprocessing.
        error_log_file = 'error_log.txt'
        with open(error_log_file, 'w', encoding='utf-8') as f:
            for row in error_rows:
                f.write(f"Row {row[0]}, Patent {row[1]}: {row[2]}\n")

        return embeddings_file, metadata_file, removed_words_file, f"Processing complete. Encoded {len(filtered_texts)} patents. Skipped {len(skipped_rows)} patents due to missing data. Errors in {len(error_rows)} rows. See error_log.txt for details."

    except Exception as e:
        # Log the full traceback and surface a short error message in the UI.
        logging.error(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"An error occurred: {str(e)}"
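
# For reference, a minimal sketch of how the generated files can be read back
# downstream (not part of this app; the file names below are placeholders):
#
#   import h5py, json
#   with h5py.File("patent_embeddings.h5", "r") as f:
#       embeddings = f["embeddings"][:]
#       patent_numbers = [p.decode() if isinstance(p, bytes) else p
#                         for p in f["patent_numbers"][:]]
#   with open("patent_metadata.jsonl", encoding="utf-8") as f:
#       metadata = [json.loads(line) for line in f]
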
iface = gr.Interface(
    fn=process_file,
    inputs=gr.File(label="Upload a CSV file with patent data"),
    outputs=[
        gr.File(label="Patent Embeddings (HDF5)"),
        gr.File(label="Patent Metadata (JSONL)"),
        gr.File(label="Removed Words List (TXT)"),
        gr.Textbox(label="Processing Status")
    ],
    title="Patent Text Encoder",
    description="Upload a CSV file containing patent data (must include 'Master Patent Number', 'Abstract', and 'Claims' columns). The app will generate embeddings and save them along with metadata as downloadable files.",
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()