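"""Patent similarity search app.

Combines dense retrieval (bert-for-patents embeddings searched with FAISS)
with sparse TF-IDF retrieval and a key-feature overlap score, served
through a Gradio interface.
"""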
import gradio as gr
import numpy as np
import h5py
import faiss
import json
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Download necessary NLTK data
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by word_tokenize on newer NLTK releases
# Load SentenceTransformer model
model = SentenceTransformer('anferico/bert-for-patents')
def preprocess_query(text):
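    """Normalize patent text for matching.

    Strips "[EN]" labels and claim numbers, lowercases words while keeping
    acronyms and number+unit tokens intact, joins numbers to their units
    (e.g. "5mm" -> "5_mm", "1 to 5 mm" -> "1_to_5_mm"), and drops stopwords.
    """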
# Remove "[EN]" label and claim numbers
text = re.sub(r'\[EN\]\s*', '', text)
text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE)
    # Convert to lowercase while preserving acronyms and number+unit tokens
    words = text.split()
    text = ' '.join(
        word if word.isupper() or re.match(r'^\d+(?:\.\d+)?[a-zA-Z]+$', word) else word.lower()
        for word in words
    )
# Remove special characters except hyphens and periods in numbers
text = re.sub(r'[^\w\s\-.]', ' ', text)
text = re.sub(r'(?<!\d)\.(?!\d)', ' ', text) # Remove periods not in numbers
# Normalize spaces
text = re.sub(r'\s+', ' ', text).strip()
    # Preserve numerical values with units ("5mm" -> "5_mm")
    text = re.sub(r'(\d+(?:\.\d+)?)([a-zA-Z]+)', r'\1_\2', text)
    # Handle ranges and measurements before stopword removal ("to", "between",
    # and "and" are stopwords and would otherwise be dropped first); the
    # non-capturing groups keep the backreferences on the numbers and the unit
    text = re.sub(r'(\d+(?:\.\d+)?)\s*to\s*(\d+(?:\.\d+)?)\s*([a-zA-Z_]+)', r'\1_to_\2_\3', text)
    text = re.sub(r'between\s*(\d+(?:\.\d+)?)\s*and\s*(\d+(?:\.\d+)?)\s*([a-zA-Z_]+)', r'between_\1_and_\2_\3', text)
    # Tokenize
    tokens = word_tokenize(text)
    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word.lower() not in stop_words]
    # Join tokens back into text
    text = ' '.join(tokens)
    # Chemical formulas such as "H2O" are already single tokens at this
    # point, so they pass through unchanged
return text
def extract_key_features(text):
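    """Preprocess `text` and return its unique terms, preserving order."""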
# For queries, we'll just preprocess and return all non-stopword terms
processed_text = preprocess_query(text)
# Split the processed text into individual terms
features = processed_text.split()
# Remove duplicates while preserving order
features = list(dict.fromkeys(features))
return features
def encode_texts(texts):
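    """Embed a list of texts with the bert-for-patents SentenceTransformer."""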
embeddings = model.encode(texts, show_progress_bar=True)
return embeddings
def load_data():
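    """Load precomputed patent embeddings from HDF5 and per-patent
    metadata (including full text) from a JSONL file."""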
try:
with h5py.File('patent_embeddings.h5', 'r') as f:
embeddings = f['embeddings'][:]
patent_numbers = f['patent_numbers'][:]
metadata = {}
texts = []
with open('patent_metadata.jsonl', 'r') as f:
for line in f:
data = json.loads(line)
metadata[data['patent_number']] = data
texts.append(data['text'])
print(f"Embedding shape: {embeddings.shape}")
print(f"Number of patent numbers: {len(patent_numbers)}")
print(f"Number of metadata entries: {len(metadata)}")
return embeddings, patent_numbers, metadata, texts
except Exception as e:
print(f"An error occurred while loading data: {e}")
raise
def compare_features(query_features, patent_features):
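    """Return the shared features and an overlap score (shared count
    divided by the size of the larger feature set)."""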
common_features = set(query_features) & set(patent_features)
similarity_score = len(common_features) / max(len(query_features), len(patent_features))
return common_features, similarity_score
def hybrid_search(query, top_k=5):
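    """Rank patents by a weighted blend of three signals: FAISS cosine
    similarity, TF-IDF cosine similarity, and key-feature overlap.

    Returns a formatted string describing the top_k patents.
    """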
print(f"Original query: {query}")
processed_query = preprocess_query(query)
    # extract_key_features preprocesses internally, so pass the raw query
    # to avoid running preprocess_query twice
    query_features = extract_key_features(query)
# Encode the processed query using the SentenceTransformer model
query_embedding = encode_texts([processed_query])[0]
query_embedding = query_embedding / np.linalg.norm(query_embedding)
# Perform semantic similarity search
    semantic_scores, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
# Perform TF-IDF based search
query_tfidf = tfidf_vectorizer.transform([processed_query])
tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
# Combine and rank results
combined_results = {}
for i, idx in enumerate(semantic_indices[0]):
patent_number = patent_numbers[idx].decode('utf-8')
text = metadata[patent_number]['text']
patent_features = extract_key_features(text)
common_features, feature_similarity = compare_features(query_features, patent_features)
combined_results[patent_number] = {
            'score': semantic_scores[0][i] * 1.0 + tfidf_similarities[idx] * 0.5 + feature_similarity,
'common_features': common_features,
'text': text
}
for idx in tfidf_indices:
patent_number = patent_numbers[idx].decode('utf-8')
if patent_number not in combined_results:
text = metadata[patent_number]['text']
patent_features = extract_key_features(text)
common_features, feature_similarity = compare_features(query_features, patent_features)
combined_results[patent_number] = {
'score': tfidf_similarities[idx] * 1.0 + feature_similarity,
'common_features': common_features,
'text': text
}
# Sort and get top results
top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
results = []
for patent_number, data in top_results:
result = f"Patent Number: {patent_number}\n"
result += f"Text: {data['text'][:200]}...\n"
result += f"Combined Score: {data['score']:.4f}\n"
result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
results.append(result)
return "\n".join(results)
# Load data and prepare the FAISS index
embeddings, patent_numbers, metadata, texts = load_data()
# Normalize embeddings for cosine similarity
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# Create FAISS index for cosine similarity
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
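# IndexFlatIP performs exact (brute-force) inner-product search; on the
# unit-normalized rows added above, the scores it returns are cosine similarities.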
# Create TF-IDF vectorizer
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
# Create Gradio interface
iface = gr.Interface(
fn=hybrid_search,
inputs=[
gr.Textbox(lines=2, placeholder="Enter your patent query here..."),
gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Top K Results"),
],
outputs=gr.Textbox(lines=10, label="Search Results"),
title="Patent Similarity Search",
description="Enter a patent description to find similar patents based on key features."
)
if __name__ == "__main__":
iface.launch()
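# Example (hypothetical query) when driving the search directly from Python:
#   print(hybrid_search("1. A lithium-ion battery electrode comprising a "
#                       "silicon anode between 5 and 50 micrometers thick.", top_k=3))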