import gradio as gr
import pandas as pd
import numpy as np
import h5py
import json
import os
import tempfile
import re
import time
import logging
import traceback
from sentence_transformers import SentenceTransformer
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk
import torch
from sklearn.feature_extraction.text import CountVectorizer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Ensure the necessary NLTK data is available
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by word_tokenize on newer NLTK releases

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model from Hugging Face
logging.info("Loading SentenceTransformer model...")
model = SentenceTransformer('anferico/bert-for-patents').to(device)
logging.info("SentenceTransformer model loaded successfully.")


def preprocess_text(text):
    # Remove "[EN]" label and claim numbers
    text = re.sub(r'\[EN\]\s*', '', text)
    text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)

    # Convert to lowercase while preserving acronyms and units (e.g. "USB", "10mm")
    words = text.split()
    text = ' '.join(
        word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower()
        for word in words
    )

    # Remove special characters except hyphens and periods in numbers
    text = re.sub(r'[^\w\s\-.]', ' ', text)
    text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text)

    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # Tokenize and remove English stopwords
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(text)
    tokens = [word for word in tokens if word.lower() not in stop_words]

    return ' '.join(tokens)


def filter_common_terms(texts, threshold=0.10):
    # Count in how many documents each term appears
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(texts)
    document_frequencies = np.sum(X.toarray() > 0, axis=0)
    num_documents = X.shape[0]

    # Collect terms that occur in more than `threshold` of all documents
    common_terms = set()
    removed_words = {}
    for term, doc_freq in zip(vectorizer.get_feature_names_out(), document_frequencies):
        if doc_freq / num_documents > threshold:
            common_terms.add(term)
            removed_words[term] = doc_freq

    # Strip the common terms from every document
    filtered_texts = []
    for text in texts:
        filtered_text = ' '.join([word for word in text.split() if word not in common_terms])
        filtered_texts.append(filtered_text)

    return filtered_texts, removed_words


def encode_texts(texts, progress=gr.Progress(), batch_size=64):
    embeddings = []
    total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)

    for i in range(0, len(texts), batch_size):
        batch_texts = [str(text) for text in texts[i:i + batch_size]]
        batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
        embeddings.extend(batch_embeddings)
        progress((i // batch_size + 1) / total_batches,
                 f"Processing batch {i // batch_size + 1}/{total_batches}")

    # L2-normalize so that dot products correspond to cosine similarity
    embeddings = np.array(embeddings)
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings


def process_file(file, progress=gr.Progress()):
    try:
        start_time = time.time()

        # Read CSV file
        df = pd.read_csv(file.name, encoding='utf-8')
        logging.info(f"CSV file read successfully. Shape: {df.shape}")
Shape: {df.shape}") required_columns = ['Master Patent Number', 'Abstract', 'Claims'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: return None, None, None, f"Error: Missing columns: {', '.join(missing_columns)}" valid_texts = [] valid_patent_numbers = [] skipped_rows = [] error_rows = [] total_rows = len(df) for index, row in df.iterrows(): try: progress((index + 1) / total_rows, f"Processing row {index + 1}/{total_rows}") logging.info(f"Processing row {index + 1}/{total_rows}") abstract = row['Abstract'] if pd.notna(row['Abstract']) else '' claims = row['Claims'] if pd.notna(row['Claims']) else '' if not abstract and not claims: skipped_rows.append(row['Master Patent Number']) continue # Preprocess the abstract and claims separately preprocessed_abstract = preprocess_text(abstract) preprocessed_claims = preprocess_text(claims) # Combine preprocessed abstract and claims combined_text = preprocessed_abstract + ' ' + preprocessed_claims valid_texts.append(combined_text) valid_patent_numbers.append(str(row['Master Patent Number'])) except Exception as e: error_message = f"Error processing row {index + 1}: {str(e)}" logging.error(error_message) error_rows.append((index, row['Master Patent Number'], error_message)) continue logging.info(f"Preprocessed abstracts and claims. Number of valid texts: {len(valid_texts)}") if skipped_rows: logging.info(f"Skipped {len(skipped_rows)} rows due to missing abstract and claims.") if error_rows: logging.info(f"Encountered errors in {len(error_rows)} rows.") # Filter out common terms logging.info("Filtering common terms...") filtered_texts, removed_words = filter_common_terms(valid_texts, threshold=0.10) # Generate removed words file removed_words_file = 'removed_words.txt' with open(removed_words_file, 'w', encoding='utf-8') as f: for word, count in sorted(removed_words.items(), key=lambda x: x[1], reverse=True): f.write(f"{word}: {count}\n") logging.info("Encoding texts...") embeddings = encode_texts(filtered_texts, progress) logging.info("Texts encoded successfully.") # Save embeddings and metadata embeddings_file = tempfile.NamedTemporaryFile(delete=False, suffix='.h5').name with h5py.File(embeddings_file, 'w') as f: f.create_dataset('embeddings', data=embeddings) f.create_dataset('patent_numbers', data=valid_patent_numbers) metadata_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl').name with open(metadata_file, 'w', encoding='utf-8') as f: for index, (patent_number, text) in enumerate(zip(valid_patent_numbers, filtered_texts)): json.dump({ 'index': index, 'patent_number': patent_number, 'text': text, 'embedding_index': index }, f, ensure_ascii=False) f.write('\n') end_time = time.time() total_time = end_time - start_time logging.info(f"Processing completed in {total_time:.2f} seconds.") # Save error log error_log_file = 'error_log.txt' with open(error_log_file, 'w', encoding='utf-8') as f: for row in error_rows: f.write(f"Row {row[0]}, Patent {row[1]}: {row[2]}\n") return embeddings_file, metadata_file, removed_words_file, f"Processing complete. Encoded {len(filtered_texts)} patents. Skipped {len(skipped_rows)} patents due to missing data. Errors in {len(error_rows)} rows. See error_log.txt for details." 
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        traceback.print_exc()
        return None, None, None, f"An error occurred: {str(e)}"


iface = gr.Interface(
    fn=process_file,
    inputs=gr.File(label="Upload a CSV file with patent data"),
    outputs=[
        gr.File(label="Patent Embeddings (HDF5)"),
        gr.File(label="Patent Metadata (JSONL)"),
        gr.File(label="Removed Words List (TXT)"),
        gr.Textbox(label="Processing Status")
    ],
    title="Patent Text Encoder",
    description="Upload a CSV file containing patent data (must include 'Master Patent Number', 'Abstract', and 'Claims' columns). The app will generate embeddings and save them along with metadata as downloadable files.",
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()
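
# --- Illustrative notes (comments only, not executed) ---
# Expected input: a CSV with at least these columns (the example row is hypothetical):
#   Master Patent Number,Abstract,Claims
#   US1234567,"A method for ...","1. A device comprising ..."
#
# A minimal sketch of how a downstream script might load the generated outputs,
# assuming the downloaded files were saved as 'patent_embeddings.h5' and
# 'patent_metadata.jsonl' (these filenames are assumptions, not produced by this app):
#   import h5py, json
#   with h5py.File('patent_embeddings.h5', 'r') as f:
#       embeddings = f['embeddings'][:]
#       patent_numbers = [pn.decode('utf-8') for pn in f['patent_numbers'][:]]
#   with open('patent_metadata.jsonl', 'r', encoding='utf-8') as f:
#       metadata = [json.loads(line) for line in f]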