bhlewis commited on
Commit
61b33b0
1 Parent(s): bbb0782

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -162
app.py CHANGED
@@ -1,204 +1,228 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  import h5py
4
- import faiss
5
  import json
6
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
7
- from sklearn.feature_extraction.text import TfidfVectorizer
8
- from sklearn.metrics.pairwise import cosine_similarity
9
  import re
10
- from collections import Counter
11
- import torch
 
12
  from nltk.corpus import stopwords
13
  from nltk.tokenize import word_tokenize
14
  import nltk
 
 
15
 
16
- # Download necessary NLTK data
 
 
 
17
  nltk.download('stopwords', quiet=True)
18
  nltk.download('punkt', quiet=True)
19
 
20
- # Load BERT model for lemmatization
21
- bert_lemma_model_name = "bert-base-uncased"
22
- bert_lemma_tokenizer = AutoTokenizer.from_pretrained(bert_lemma_model_name)
23
- bert_lemma_model = AutoModelForMaskedLM.from_pretrained(bert_lemma_model_name).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
24
 
25
- # Load BERT model for encoding search queries
26
- bert_encode_model_name = 'anferico/bert-for-patents'
27
- bert_encode_tokenizer = AutoTokenizer.from_pretrained(bert_encode_model_name)
28
- bert_encode_model = AutoModel.from_pretrained(bert_encode_model_name)
29
 
30
- def bert_lemmatize(text):
31
- tokens = bert_lemma_tokenizer.tokenize(text)
32
- input_ids = bert_lemma_tokenizer.convert_tokens_to_ids(tokens)
33
- input_tensor = torch.tensor([input_ids]).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
34
- with torch.no_grad():
35
- outputs = bert_lemma_model(input_tensor)
36
- predictions = outputs.logits.argmax(dim=-1)
37
- lemmatized_tokens = bert_lemma_tokenizer.convert_ids_to_tokens(predictions[0])
38
- return ' '.join([token for token in lemmatized_tokens if token not in ['[CLS]', '[SEP]', '[PAD]']])
39
 
40
- def preprocess_query(text):
41
- # Convert to lowercase
42
- text = text.lower()
 
 
 
 
 
43
 
44
- # Remove any HTML tags (if present)
45
- text = re.sub('<.*?>', '', text)
 
46
 
47
- # Remove special characters, but keep hyphens, periods, and commas
48
- text = re.sub(r'[^a-zA-Z0-9\s\-\.\,]', '', text)
49
 
50
  # Tokenize
51
  tokens = word_tokenize(text)
52
 
53
- # Remove stopwords, but keep all other words
54
  stop_words = set(stopwords.words('english'))
55
- tokens = [word for word in tokens if word not in stop_words]
56
 
57
- # Join tokens back into a string
58
- processed_text = ' '.join(tokens)
59
 
60
- # Apply BERT lemmatization
61
- processed_text = bert_lemmatize(processed_text)
62
 
63
- return processed_text
 
 
 
 
 
 
 
64
 
65
- def extract_key_features(text):
66
- # For queries, we'll just preprocess and return all non-stopword terms
67
- processed_text = preprocess_query(text)
 
 
 
68
 
69
- # Split the processed text into individual terms
70
- features = processed_text.split()
 
 
 
 
71
 
72
- # Remove duplicates while preserving order
73
- features = list(dict.fromkeys(features))
 
 
74
 
75
- return features
76
 
77
- def encode_texts(texts, max_length=512):
78
- inputs = bert_encode_tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
79
- with torch.no_grad():
80
- outputs = bert_encode_model(**inputs)
81
- embeddings = outputs.last_hidden_state.mean(dim=1)
82
- return embeddings.numpy()
 
 
 
 
 
 
 
 
83
 
84
- def load_data():
85
  try:
86
- with h5py.File('patent_embeddings.h5', 'r') as f:
87
- embeddings = f['embeddings'][:]
88
- patent_numbers = f['patent_numbers'][:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- metadata = {}
91
- texts = []
92
- with open('patent_metadata.jsonl', 'r') as f:
93
- for line in f:
94
- data = json.loads(line)
95
- metadata[data['patent_number']] = data
96
- texts.append(data['text'])
97
 
98
- print(f"Embedding shape: {embeddings.shape}")
99
- print(f"Number of patent numbers: {len(patent_numbers)}")
100
- print(f"Number of metadata entries: {len(metadata)}")
 
 
101
 
102
- return embeddings, patent_numbers, metadata, texts
103
- except FileNotFoundError as e:
104
- print(f"Error: Could not find file. {e}")
105
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  except Exception as e:
107
- print(f"An unexpected error occurred while loading data: {e}")
108
- raise
109
-
110
- def compare_features(query_features, patent_features):
111
- common_features = set(query_features) & set(patent_features)
112
- similarity_score = len(common_features) / max(len(query_features), len(patent_features))
113
- return common_features, similarity_score
114
-
115
- def hybrid_search(query, top_k=5):
116
- print(f"Original query: {query}")
117
-
118
- processed_query = preprocess_query(query)
119
- query_features = extract_key_features(processed_query)
120
-
121
- # Encode the processed query using the transformer model
122
- query_embedding = encode_texts([processed_query])[0]
123
- query_embedding = query_embedding / np.linalg.norm(query_embedding)
124
-
125
- # Perform semantic similarity search
126
- semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
127
-
128
- # Perform TF-IDF based search
129
- query_tfidf = tfidf_vectorizer.transform([processed_query])
130
- tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
131
- tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
132
-
133
- # Combine and rank results
134
- combined_results = {}
135
- for i, idx in enumerate(semantic_indices[0]):
136
- patent_number = patent_numbers[idx].decode('utf-8')
137
- text = metadata[patent_number]['text']
138
- patent_features = extract_key_features(text)
139
- common_features, feature_similarity = compare_features(query_features, patent_features)
140
- combined_results[patent_number] = {
141
- 'score': semantic_distances[0][i] * 1.0 + tfidf_similarities[idx] * 0.5 + feature_similarity,
142
- 'common_features': common_features,
143
- 'text': text
144
- }
145
-
146
- for idx in tfidf_indices:
147
- patent_number = patent_numbers[idx].decode('utf-8')
148
- if patent_number not in combined_results:
149
- text = metadata[patent_number]['text']
150
- patent_features = extract_key_features(text)
151
- common_features, feature_similarity = compare_features(query_features, patent_features)
152
- combined_results[patent_number] = {
153
- 'score': tfidf_similarities[idx] * 1.0 + feature_similarity,
154
- 'common_features': common_features,
155
- 'text': text
156
- }
157
-
158
- # Sort and get top results
159
- top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
160
-
161
- results = []
162
- for patent_number, data in top_results:
163
- result = f"Patent Number: {patent_number}\n"
164
- result += f"Text: {data['text'][:200]}...\n"
165
- result += f"Combined Score: {data['score']:.4f}\n"
166
- result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
167
- results.append(result)
168
-
169
- return "\n".join(results)
170
-
171
- # Load data and prepare the FAISS index
172
- embeddings, patent_numbers, metadata, texts = load_data()
173
-
174
- # Check if the embedding dimensions match
175
- if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
176
- print("Embedding dimensions do not match. Rebuilding FAISS index.")
177
- # Rebuild embeddings using the new model
178
- embeddings = encode_texts(texts)
179
- embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
180
-
181
- # Normalize embeddings for cosine similarity
182
- embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
183
-
184
- # Create FAISS index for cosine similarity
185
- index = faiss.IndexFlatIP(embeddings.shape[1])
186
- index.add(embeddings)
187
-
188
- # Create TF-IDF vectorizer
189
- tfidf_vectorizer = TfidfVectorizer(stop_words='english')
190
- tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
191
 
192
- # Create Gradio interface with additional input fields
193
  iface = gr.Interface(
194
- fn=hybrid_search,
195
- inputs=[
196
- gr.Textbox(lines=2, placeholder="Enter your patent query here..."),
197
- gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Top K Results"),
 
 
 
198
  ],
199
- outputs=gr.Textbox(lines=10, label="Search Results"),
200
- title="Patent Similarity Search",
201
- description="Enter a patent description to find similar patents based on key features."
 
202
  )
203
 
204
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import numpy as np
4
  import h5py
 
5
  import json
6
+ import os
7
+ import tempfile
 
8
  import re
9
+ import time
10
+ import logging
11
+ from sentence_transformers import SentenceTransformer
12
  from nltk.corpus import stopwords
13
  from nltk.tokenize import word_tokenize
14
  import nltk
15
+ import torch
16
+ from sklearn.feature_extraction.text import CountVectorizer
17
 
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ # Ensure you have downloaded the necessary NLTK data
22
  nltk.download('stopwords', quiet=True)
23
  nltk.download('punkt', quiet=True)
24
 
25
+ # Disable tokenizer parallelism warning
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
27
 
28
+ # Check for GPU availability
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
30
 
31
+ # Load pre-trained model from Hugging Face
32
+ logging.info("Loading SentenceTransformer model...")
33
+ model = SentenceTransformer('anferico/bert-for-patents').to(device)
34
+ logging.info("SentenceTransformer model loaded successfully.")
 
 
 
 
 
35
 
36
+ def preprocess_text(text):
37
+ # Remove "[EN]" label and claim numbers
38
+ text = re.sub(r'\[EN\]\s*', '', text)
39
+ text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)
40
+
41
+ # Convert to lowercase while preserving acronyms and units
42
+ words = text.split()
43
+ text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)
44
 
45
+ # Remove special characters except hyphens and periods in numbers
46
+ text = re.sub(r'[^\w\s\-.]', ' ', text)
47
+ text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text) # Remove periods not in numbers
48
 
49
+ # Normalize spaces
50
+ text = re.sub(r'\s+', ' ', text).strip()
51
 
52
  # Tokenize
53
  tokens = word_tokenize(text)
54
 
55
+ # Remove stopwords
56
  stop_words = set(stopwords.words('english'))
57
+ tokens = [word for word in tokens if word.lower() not in stop_words]
58
 
59
+ # Join tokens back into text
60
+ text = ' '.join(tokens)
61
 
62
+ # Preserve numerical values with units
63
+ text = re.sub(r'(\d+(\.\d+)?)([a-zA-Z]+)', r'\1_\3', text)
64
 
65
+ # Handle ranges and measurements
66
+ text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
67
+ text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\5', text)
68
+
69
+ # Preserve chemical formulas
70
+ text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)
71
+
72
+ return text
73
 
74
+ def filter_common_terms(texts, threshold=0.10):
75
+ vectorizer = CountVectorizer()
76
+ X = vectorizer.fit_transform(texts)
77
+ term_frequencies = np.sum(X.toarray(), axis=0)
78
+ document_frequencies = np.sum(X.toarray() > 0, axis=0)
79
+ num_documents = X.shape[0]
80
 
81
+ common_terms = set()
82
+ removed_words = {}
83
+ for term, doc_freq in zip(vectorizer.get_feature_names_out(), document_frequencies):
84
+ if doc_freq / num_documents > threshold:
85
+ common_terms.add(term)
86
+ removed_words[term] = doc_freq
87
 
88
+ filtered_texts = []
89
+ for text in texts:
90
+ filtered_text = ' '.join([word for word in text.split() if word not in common_terms])
91
+ filtered_texts.append(filtered_text)
92
 
93
+ return filtered_texts, removed_words
94
 
95
+ def encode_texts(texts, progress=gr.Progress(), batch_size=64):
96
+ embeddings = []
97
+ total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)
98
+
99
+ for i in range(0, len(texts), batch_size):
100
+ batch_texts = texts[i:i+batch_size]
101
+ batch_texts = [str(text) for text in batch_texts]
102
+ batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
103
+ embeddings.extend(batch_embeddings)
104
+ progress((i // batch_size + 1) / total_batches, f"Processing batch {i // batch_size + 1}/{total_batches}")
105
+
106
+ embeddings = np.array(embeddings)
107
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
108
+ return embeddings
109
 
110
+ def process_file(file, progress=gr.Progress()):
111
  try:
112
+ start_time = time.time()
113
+
114
+ # Read CSV file
115
+ df = pd.read_csv(file.name, encoding='utf-8')
116
+ logging.info(f"CSV file read successfully. Shape: {df.shape}")
117
+
118
+ required_columns = ['Master Patent Number', 'Abstract', 'Claims']
119
+ missing_columns = [col for col in required_columns if col not in df.columns]
120
+ if missing_columns:
121
+ return None, None, None, f"Error: Missing columns: {', '.join(missing_columns)}"
122
+
123
+ valid_texts = []
124
+ valid_patent_numbers = []
125
+ skipped_rows = []
126
+ error_rows = []
127
+ total_rows = len(df)
128
+
129
+ for index, row in df.iterrows():
130
+ try:
131
+ progress((index + 1) / total_rows, f"Processing row {index + 1}/{total_rows}")
132
+ logging.info(f"Processing row {index + 1}/{total_rows}")
133
+
134
+ abstract = row['Abstract'] if pd.notna(row['Abstract']) else ''
135
+ claims = row['Claims'] if pd.notna(row['Claims']) else ''
136
+
137
+ if not abstract and not claims:
138
+ skipped_rows.append(row['Master Patent Number'])
139
+ continue
140
+
141
+ # Preprocess the abstract and claims separately
142
+ preprocessed_abstract = preprocess_text(abstract)
143
+ preprocessed_claims = preprocess_text(claims)
144
+
145
+ # Combine preprocessed abstract and claims
146
+ combined_text = preprocessed_abstract + ' ' + preprocessed_claims
147
+
148
+ valid_texts.append(combined_text)
149
+ valid_patent_numbers.append(str(row['Master Patent Number']))
150
+
151
+ except Exception as e:
152
+ error_message = f"Error processing row {index + 1}: {str(e)}"
153
+ logging.error(error_message)
154
+ error_rows.append((index, row['Master Patent Number'], error_message))
155
+ continue
156
+
157
+ logging.info(f"Preprocessed abstracts and claims. Number of valid texts: {len(valid_texts)}")
158
+
159
+ if skipped_rows:
160
+ logging.info(f"Skipped {len(skipped_rows)} rows due to missing abstract and claims.")
161
+ if error_rows:
162
+ logging.info(f"Encountered errors in {len(error_rows)} rows.")
163
 
164
+ # Filter out common terms
165
+ logging.info("Filtering common terms...")
166
+ filtered_texts, removed_words = filter_common_terms(valid_texts, threshold=0.10)
 
 
 
 
167
 
168
+ # Generate removed words file
169
+ removed_words_file = 'removed_words.txt'
170
+ with open(removed_words_file, 'w', encoding='utf-8') as f:
171
+ for word, count in sorted(removed_words.items(), key=lambda x: x[1], reverse=True):
172
+ f.write(f"{word}: {count}\n")
173
 
174
+ logging.info("Encoding texts...")
175
+ embeddings = encode_texts(filtered_texts, progress)
176
+ logging.info("Texts encoded successfully.")
177
+
178
+ # Save embeddings and metadata
179
+ embeddings_file = tempfile.NamedTemporaryFile(delete=False, suffix='.h5').name
180
+ with h5py.File(embeddings_file, 'w') as f:
181
+ f.create_dataset('embeddings', data=embeddings)
182
+ f.create_dataset('patent_numbers', data=valid_patent_numbers)
183
+
184
+ metadata_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl').name
185
+ with open(metadata_file, 'w', encoding='utf-8') as f:
186
+ for index, (patent_number, text) in enumerate(zip(valid_patent_numbers, filtered_texts)):
187
+ json.dump({
188
+ 'index': index,
189
+ 'patent_number': patent_number,
190
+ 'text': text,
191
+ 'embedding_index': index
192
+ }, f, ensure_ascii=False)
193
+ f.write('\n')
194
+
195
+ end_time = time.time()
196
+ total_time = end_time - start_time
197
+ logging.info(f"Processing completed in {total_time:.2f} seconds.")
198
+
199
+ # Save error log
200
+ error_log_file = 'error_log.txt'
201
+ with open(error_log_file, 'w', encoding='utf-8') as f:
202
+ for row in error_rows:
203
+ f.write(f"Row {row[0]}, Patent {row[1]}: {row[2]}\n")
204
+
205
+ return embeddings_file, metadata_file, removed_words_file, f"Processing complete. Encoded {len(filtered_texts)} patents. Skipped {len(skipped_rows)} patents due to missing data. Errors in {len(error_rows)} rows. See error_log.txt for details."
206
+
207
  except Exception as e:
208
+ logging.error(f"An error occurred: {e}")
209
+ import traceback
210
+ traceback.print_exc()
211
+ return None, None, None, f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
213
  iface = gr.Interface(
214
+ fn=process_file,
215
+ inputs=gr.File(label="Upload a CSV file with patent data"),
216
+ outputs=[
217
+ gr.File(label="Patent Embeddings (HDF5)"),
218
+ gr.File(label="Patent Metadata (JSONL)"),
219
+ gr.File(label="Removed Words List (TXT)"),
220
+ gr.Textbox(label="Processing Status")
221
  ],
222
+ title="Patent Text Encoder",
223
+ description="Upload a CSV file containing patent data (must include 'Master Patent Number', 'Abstract', and 'Claims' columns). The app will generate embeddings and save them along with metadata as downloadable files.",
224
+ allow_flagging="never",
225
+ cache_examples=False,
226
  )
227
 
228
  if __name__ == "__main__":