Update app.py
Browse files
@@ -1,229 +1,205 @@
1 |
import gradio as gr
2 |
import pandas as pd
3 |
import numpy as np
4 |
import h5py
5 |
import json
6 |
7 |
8 |
import re
9 |
10 |
11 |
from sentence_transformers import SentenceTransformer
12 |
from nltk.corpus import stopwords
13 |
from nltk.tokenize import word_tokenize
14 |
import nltk
15 |
import torch
16 |
from sklearn.feature_extraction.text import CountVectorizer
17 |
18 |
19 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20 |
21 |
# Ensure you have downloaded the necessary NLTK data
22 |
nltk.download('stopwords', quiet=True)
23 |
nltk.download('punkt', quiet=True)
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
text =
39 |
text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)
40 |
41 |
# Convert to lowercase while preserving acronyms and units
42 |
words = text.split()
43 |
text = ' '.join(word if word.isupper() or re.match(r'^\d+(\.\d+)?[a-zA-Z]+$', word) else word.lower() for word in words)
44 |
45 |
# Remove
46 |
text = re.sub(
47 |
text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text) # Remove periods not in numbers
48 |
49 |
50 |
text = re.sub(r'\s
51 |
52 |
# Tokenize
53 |
tokens = word_tokenize(text)
54 |
55 |
# Remove stopwords
56 |
stop_words = set(stopwords.words('english'))
57 |
tokens = [word for word in tokens if word
58 |
59 |
# Join tokens back into
60 |
61 |
62 |
63 |
64 |
65 |
66 |
text = re.sub(r'(\d+(\.\d+)?)(\s*to\s*)(\d+(\.\d+)?)(\s*[a-zA-Z]+)', r'\1_to_\4_\6', text)
67 |
text = re.sub(r'between\s*(\d+(\.\d+)?)(\s*and\s*)(\d+(\.\d+)?)\s*([a-zA-Z]+)', r'between_\1_and_\4_\5', text)
68 |
69 |
# Preserve chemical formulas
70 |
text = re.sub(r'\b([A-Z][a-z]?\d*)+\b', lambda m: m.group().replace(' ', ''), text)
71 |
72 |
return text
73 |
74 |
75 |
76 |
77 |
term_frequencies = np.sum(X.toarray(), axis=0)
78 |
document_frequencies = np.sum(X.toarray() > 0, axis=0)
79 |
num_documents = X.shape[0]
80 |
81 |
82 |
83 |
for term, doc_freq in zip(vectorizer.get_feature_names_out(), document_frequencies):
84 |
if doc_freq / num_documents > threshold:
85 |
86 |
removed_words[term] = doc_freq
87 |
88 |
89 |
90 |
filtered_text = ' '.join([word for word in text.split() if word not in common_terms])
91 |
92 |
93 |
94 |
95 |
def encode_texts(texts,
96 |
97 |
98 |
99 |
100 |
101 |
batch_texts = [str(text) for text in batch_texts]
102 |
batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
103 |
104 |
progress((i // batch_size + 1) / total_batches, f"Processing batch {i // batch_size + 1}/{total_batches}")
105 |
106 |
embeddings = np.array(embeddings)
107 |
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
108 |
return embeddings
109 |
110 |
111 |
112 |
113 |
114 |
115 |
df = pd.read_csv(file.name, encoding='utf-8')
116 |
logging.info(f"CSV file read successfully. Shape: {df.shape}")
117 |
118 |
required_columns = ['Master Patent Number', 'Abstract', 'Claims']
119 |
missing_columns = [col for col in required_columns if col not in df.columns]
120 |
if missing_columns:
121 |
return None, None, None, f"Error: Missing columns: {', '.join(missing_columns)}"
122 |
123 |
valid_texts = []
124 |
valid_patent_numbers = []
125 |
skipped_rows = []
126 |
error_rows = []
127 |
total_rows = len(df)
128 |
129 |
for index, row in df.iterrows():
130 |
131 |
progress((index + 1) / total_rows, f"Processing row {index + 1}/{total_rows}")
132 |
logging.info(f"Processing row {index + 1}/{total_rows}")
133 |
134 |
abstract = row['Abstract'] if pd.notna(row['Abstract']) else ''
135 |
claims = row['Claims'] if pd.notna(row['Claims']) else ''
136 |
137 |
if not abstract and not claims:
138 |
skipped_rows.append(row['Master Patent Number'])
139 |
140 |
141 |
# Preprocess the abstract and claims separately
142 |
preprocessed_abstract = preprocess_text(abstract)
143 |
preprocessed_claims = preprocess_text(claims)
144 |
145 |
# Combine preprocessed abstract and claims
146 |
combined_text = preprocessed_abstract + ' ' + preprocessed_claims
147 |
148 |
149 |
valid_patent_numbers.append(str(row['Master Patent Number']))
150 |
151 |
except Exception as e:
152 |
error_message = f"Error processing row {index + 1}: {str(e)}"
153 |
154 |
error_rows.append((index, row['Master Patent Number'], error_message))
155 |
156 |
157 |
logging.info(f"Preprocessed abstracts and claims. Number of valid texts: {len(valid_texts)}")
158 |
159 |
if skipped_rows:
160 |
logging.info(f"Skipped {len(skipped_rows)} rows due to missing abstract and claims.")
161 |
if error_rows:
162 |
logging.info(f"Encountered errors in {len(error_rows)} rows.")
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
for word, count in sorted(removed_words.items(), key=lambda x: x[1], reverse=True):
172 |
f.write(f"{word}: {count}\n")
173 |
174 |
175 |
176 |
177 |
178 |
# Save embeddings and metadata
179 |
embeddings_file = tempfile.NamedTemporaryFile(delete=False, suffix='.h5').name
180 |
with h5py.File(embeddings_file, 'w') as f:
181 |
f.create_dataset('embeddings', data=embeddings)
182 |
f.create_dataset('patent_numbers', data=valid_patent_numbers)
183 |
184 |
metadata_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl').name
185 |
with open(metadata_file, 'w', encoding='utf-8') as f:
186 |
for index, (patent_number, text) in enumerate(zip(valid_patent_numbers, filtered_texts)):
187 |
188 |
'index': index,
189 |
'patent_number': patent_number,
190 |
'text': text,
191 |
'embedding_index': index
192 |
}, f, ensure_ascii=False)
193 |
194 |
195 |
end_time = time.time()
196 |
total_time = end_time - start_time
197 |
logging.info(f"Processing completed in {total_time:.2f} seconds.")
198 |
199 |
# Save error log
200 |
error_log_file = 'error_log.txt'
201 |
with open(error_log_file, 'w', encoding='utf-8') as f:
202 |
for row in error_rows:
203 |
f.write(f"Row {row[0]}, Patent {row[1]}: {row[2]}\n")
204 |
205 |
return embeddings_file, metadata_file, removed_words_file, f"Processing complete. Encoded {len(filtered_texts)} patents. Skipped {len(skipped_rows)} patents due to missing data. Errors in {len(error_rows)} rows. See error_log.txt for details."
206 |
207 |
except Exception as e:
208 |
209 |
210 |
211 |
212 |
213 |
iface = gr.Interface(
214 |
215 |
216 |
217 |
218 |
gr.File(label="Patent Metadata (JSONL)"),
219 |
gr.File(label="Removed Words List (TXT)"),
220 |
gr.Textbox(label="Processing Status")
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
if __name__ == "__main__":
229 |
1 |
import gradio as gr
2 |
import numpy as np
3 |
import h5py
4 |
import faiss
5 |
import json
6 |
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
7 |
from sklearn.feature_extraction.text import TfidfVectorizer
8 |
from sklearn.metrics.pairwise import cosine_similarity
9 |
import re
10 |
from collections import Counter
11 |
import torch
12 |
from nltk.corpus import stopwords
13 |
from nltk.tokenize import word_tokenize
14 |
import nltk
15 |
16 |
# Download necessary NLTK data.
# 'punkt_tab' is the tokenizer resource that recent NLTK releases look up
# inside word_tokenize(); 'punkt' is still fetched for older releases.
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Load BERT model for lemmatization (masked-LM head, moved to GPU when one is available)
bert_lemma_model_name = "bert-base-uncased"
bert_lemma_tokenizer = AutoTokenizer.from_pretrained(bert_lemma_model_name)
bert_lemma_model = AutoModelForMaskedLM.from_pretrained(bert_lemma_model_name).to(
    torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

# Load BERT model for encoding search queries (left on the default device)
bert_encode_model_name = 'anferico/bert-for-patents'
bert_encode_tokenizer = AutoTokenizer.from_pretrained(bert_encode_model_name)
bert_encode_model = AutoModel.from_pretrained(bert_encode_model_name)
29 |
30 |
def bert_lemmatize(text):
    """Rewrite *text* token by token using the lemma model's masked-LM head.

    The text is tokenized with ``bert_lemma_tokenizer`` (no special tokens
    are added), run through ``bert_lemma_model``, and every position is
    replaced by the vocabulary token with the highest logit.

    Args:
        text: Input string.

    Returns:
        The predicted tokens joined by single spaces, with ``[CLS]``,
        ``[SEP]`` and ``[PAD]`` dropped. Blank input, or input that
        tokenizes to nothing, yields ``''``.
    """
    # A zero-length sequence must never reach the model.
    if not text or not text.strip():
        return ''

    tokens = bert_lemma_tokenizer.tokenize(text)
    if not tokens:
        return ''

    input_ids = bert_lemma_tokenizer.convert_tokens_to_ids(tokens)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The model only has a fixed number of position embeddings; longer inputs
    # (e.g. full patent text) are processed in consecutive windows instead of
    # failing inside the embedding lookup.
    window = bert_lemma_model.config.max_position_embeddings

    predicted_tokens = []
    with torch.no_grad():
        for start in range(0, len(input_ids), window):
            chunk = torch.tensor([input_ids[start:start + window]]).to(device)
            predicted_ids = bert_lemma_model(chunk).logits.argmax(dim=-1)[0]
            predicted_tokens.extend(
                bert_lemma_tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
            )

    special_tokens = ('[CLS]', '[SEP]', '[PAD]')
    return ' '.join(token for token in predicted_tokens if token not in special_tokens)
39 |
40 |
def preprocess_query(text):
    """Clean a query string and pass it through BERT lemmatization.

    Steps, in order: lowercase, strip HTML-style tags, drop every character
    that is not a letter, digit, whitespace, hyphen, period or comma,
    tokenize with NLTK, remove English stopwords, re-join with spaces and
    hand the result to ``bert_lemmatize``.

    Args:
        text: Raw query string.

    Returns:
        The lemmatized, space-separated query text.
    """
    lowered = text.lower()

    # Non-greedy match so each tag is removed individually.
    without_tags = re.sub('<.*?>', '', lowered)

    # Keep alphanumerics, whitespace, hyphens, periods and commas only.
    cleaned = re.sub(r'[^a-zA-Z0-9\s\-\.\,]', '', without_tags)

    candidate_tokens = word_tokenize(cleaned)
    english_stopwords = set(stopwords.words('english'))

    kept_tokens = []
    for candidate in candidate_tokens:
        if candidate in english_stopwords:
            continue
        kept_tokens.append(candidate)

    return bert_lemmatize(' '.join(kept_tokens))
64 |
65 |
def extract_key_features(text):
    """Return the distinct terms of *text* after query preprocessing.

    The text is run through ``preprocess_query`` and split on whitespace;
    repeated terms are dropped while the order of first appearance is kept.

    Args:
        text: Raw text to extract terms from.

    Returns:
        A list of unique terms in first-seen order.
    """
    already_seen = set()
    unique_terms = []
    for term in preprocess_query(text).split():
        if term in already_seen:
            continue
        already_seen.add(term)
        unique_terms.append(term)
    return unique_terms
76 |
77 |
def encode_texts(texts, max_length=512):
    """Encode texts into fixed-size vectors by mean-pooling BERT token states.

    Args:
        texts: A list of strings to encode.
        max_length: Maximum number of tokens per text; longer texts are
            truncated by the tokenizer.

    Returns:
        A NumPy array with one row per input text.
    """
    inputs = bert_encode_tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = bert_encode_model(**inputs)

    hidden_states = outputs.last_hidden_state

    # Average over real tokens only: with padding=True shorter texts in a
    # batch are filled with [PAD] positions, and including those in the mean
    # would make a text's embedding depend on what it was batched with.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
    token_sum = (hidden_states * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)  # guards against division by zero
    embeddings = token_sum / token_count

    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid if the model is moved to GPU.
    return embeddings.cpu().numpy()
83 |
84 |
def load_data():
85 |
86 |
with h5py.File('patent_embeddings.h5', 'r') as f:
87 |
embeddings = f['embeddings'][:]
88 |
patent_numbers = f['patent_numbers'][:]
89 |
90 |
metadata = {}
91 |
texts = []
92 |
with open('patent_metadata.jsonl', 'r') as f:
93 |
for line in f:
94 |
data = json.loads(line)
95 |
metadata[data['patent_number']] = data
96 |
97 |
98 |
print(f"Embedding shape: {embeddings.shape}")
99 |
print(f"Number of patent numbers: {len(patent_numbers)}")
100 |
print(f"Number of metadata entries: {len(metadata)}")
101 |
102 |
return embeddings, patent_numbers, metadata, texts
103 |
except FileNotFoundError as e:
104 |
print(f"Error: Could not find file. {e}")
105 |
106 |
except Exception as e:
107 |
print(f"An unexpected error occurred while loading data: {e}")
108 |
109 |
110 |
def compare_features(query_features, patent_features):
    """Measure the term overlap between a query and a patent.

    Args:
        query_features: Sequence of terms extracted from the query.
        patent_features: Sequence of terms extracted from the patent text.

    Returns:
        A ``(common_features, similarity_score)`` tuple, where
        ``common_features`` is the set of terms present in both inputs and
        ``similarity_score`` is the size of that set divided by the length
        of the longer input. The score is ``0.0`` when both inputs are empty.
    """
    common_features = set(query_features) & set(patent_features)
    longest = max(len(query_features), len(patent_features))

    # Two empty inputs share nothing; dividing here would raise ZeroDivisionError.
    similarity_score = len(common_features) / longest if longest else 0.0
    return common_features, similarity_score
114 |
115 |
def hybrid_search(query, top_k=5):
116 |
print(f"Original query: {query}")
117 |
118 |
processed_query = preprocess_query(query)
119 |
query_features = extract_key_features(processed_query)
120 |
121 |
# Encode the processed query using the transformer model
122 |
query_embedding = encode_texts([processed_query])[0]
123 |
query_embedding = query_embedding / np.linalg.norm(query_embedding)
124 |
125 |
# Perform semantic similarity search
126 |
semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
127 |
128 |
# Perform TF-IDF based search
129 |
query_tfidf = tfidf_vectorizer.transform([processed_query])
130 |
tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
131 |
tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
132 |
133 |
# Combine and rank results
134 |
combined_results = {}
135 |
for i, idx in enumerate(semantic_indices[0]):
136 |
patent_number = patent_numbers[idx].decode('utf-8')
137 |
text = metadata[patent_number]['text']
138 |
patent_features = extract_key_features(text)
139 |
common_features, feature_similarity = compare_features(query_features, patent_features)
140 |
combined_results[patent_number] = {
141 |
'score': semantic_distances[0][i] * 1.0 + tfidf_similarities[idx] * 0.5 + feature_similarity,
142 |
'common_features': common_features,
143 |
'text': text
144 |
145 |
146 |
for idx in tfidf_indices:
147 |
patent_number = patent_numbers[idx].decode('utf-8')
148 |
if patent_number not in combined_results:
149 |
text = metadata[patent_number]['text']
150 |
patent_features = extract_key_features(text)
151 |
common_features, feature_similarity = compare_features(query_features, patent_features)
152 |
combined_results[patent_number] = {
153 |
'score': tfidf_similarities[idx] * 1.0 + feature_similarity,
154 |
'common_features': common_features,
155 |
'text': text
156 |
157 |
158 |
# Sort and get top results
159 |
top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
160 |
161 |
results = []
162 |
for patent_number, data in top_results:
163 |
result = f"Patent Number: {patent_number}\n"
164 |
result += f"Text: {data['text'][:200]}...\n"
165 |
result += f"Combined Score: {data['score']:.4f}\n"
166 |
result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
167 |
168 |
169 |
return "\n".join(results)
170 |
171 |
# Load data and prepare the FAISS index
172 |
embeddings, patent_numbers, metadata, texts = load_data()
173 |
174 |
# Check if the embedding dimensions match
175 |
if embeddings.shape[1] != encode_texts(["test"]).shape[1]:
176 |
print("Embedding dimensions do not match. Rebuilding FAISS index.")
177 |
# Rebuild embeddings using the new model
178 |
embeddings = encode_texts(texts)
179 |
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
180 |
181 |
# Normalize embeddings for cosine similarity
182 |
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
183 |
184 |
# Create FAISS index for cosine similarity
185 |
index = faiss.IndexFlatIP(embeddings.shape[1])
186 |
187 |
188 |
# Create TF-IDF vectorizer
189 |
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
190 |
tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
191 |
192 |
# Create Gradio interface with additional input fields
193 |
iface = gr.Interface(
194 |
195 |
196 |
gr.Textbox(lines=2, placeholder="Enter your patent query here..."),
197 |
gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Top K Results"),
198 |
199 |
outputs=gr.Textbox(lines=10, label="Search Results"),
200 |
title="Patent Similarity Search",
201 |
description="Enter a patent description to find similar patents based on key features."
202 |
203 |
204 |
if __name__ == "__main__":
205 |