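# Hybrid patent similarity search: dense sentence embeddings (FAISS) combined
# with sparse TF-IDF scores, served through a Gradio web interface.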
import gradio as gr
import numpy as np
import h5py
import faiss
import json
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
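# Expects two local data files: patent_embeddings.h5 (HDF5 datasets 'embeddings'
# and 'patent_numbers') and patent_metadata.jsonl (one JSON object per line,
# each with 'patent_number' and 'text' fields).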
def load_data():
    try:
        # Load precomputed patent embeddings and their patent numbers from HDF5
        with h5py.File('patent_embeddings.h5', 'r') as f:
            embeddings = f['embeddings'][:]
            patent_numbers = f['patent_numbers'][:]
        # Load per-patent metadata from a JSON Lines file, keyed by patent number
        metadata = {}
        texts = []
        with open('patent_metadata.jsonl', 'r') as f:
            for line in f:
                data = json.loads(line)
                metadata[data['patent_number']] = data
                texts.append(data['text'])
        print(f"Embedding shape: {embeddings.shape}")
        print(f"Number of patent numbers: {len(patent_numbers)}")
        print(f"Number of metadata entries: {len(metadata)}")
        return embeddings, patent_numbers, metadata, texts
    except FileNotFoundError as e:
        print(f"Error: Could not find file. {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred while loading data: {e}")
        raise
embeddings, patent_numbers, metadata, texts = load_data()
# Normalize embeddings for cosine similarity
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# Build a FAISS inner-product index; with unit-normalized embeddings,
# inner product equals cosine similarity
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings.astype(np.float32))  # FAISS requires float32 input
# Load the sentence-transformer model (MPNet-based) used to encode search queries
model = SentenceTransformer('all-mpnet-base-v2')
# Create TF-IDF vectorizer
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
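# Hybrid retrieval: each candidate patent's score is the unweighted sum of its
# dense-embedding cosine similarity and its TF-IDF cosine similarity against
# the query; the top_k highest combined scores are returned.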
def hybrid_search(query, top_k=5):
    print(f"Searching for: {query}")
    # Encode the query and normalize it so inner product = cosine similarity
    query_embedding = model.encode([query])[0]
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    # Semantic search: fetch extra candidates so the merge step has room to rerank
    semantic_distances, semantic_indices = index.search(
        np.array([query_embedding], dtype=np.float32), top_k * 2
    )
    # Lexical search: TF-IDF cosine similarity against all patent texts
    query_tfidf = tfidf_vectorizer.transform([query])
    tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
    tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
    # Combine: sum the semantic and TF-IDF scores per patent number
    combined_results = {}
    for i, idx in enumerate(semantic_indices[0]):
        patent_number = patent_numbers[idx].decode('utf-8')
        combined_results[patent_number] = semantic_distances[0][i]
    for idx in tfidf_indices:
        patent_number = patent_numbers[idx].decode('utf-8')
        if patent_number in combined_results:
            combined_results[patent_number] += tfidf_similarities[idx]
        else:
            combined_results[patent_number] = tfidf_similarities[idx]
    # Sort by combined score and keep the top_k results
    top_results = sorted(combined_results.items(), key=lambda x: x[1], reverse=True)[:top_k]
    # Format the results for display
    results = []
    for patent_number, score in top_results:
        if patent_number not in metadata:
            print(f"Warning: Patent number {patent_number} not found in metadata")
            continue
        patent_data = metadata[patent_number]
        result = f"Patent Number: {patent_number}\n"
        text = patent_data.get('text', 'No text available')
        result += f"Text: {text[:200]}...\n"
        result += f"Combined Score: {score:.4f}\n\n"
        results.append(result)
    return "\n".join(results)
# Create Gradio interface
iface = gr.Interface(
    fn=hybrid_search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your search query here..."),
    outputs=gr.Textbox(lines=10, label="Search Results"),
    title="Patent Similarity Search",
    description="Enter a query to find similar patents based on their content.",
)

if __name__ == "__main__":
    iface.launch()
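# launch() blocks and serves the app locally; Gradio's default address is
# http://127.0.0.1:7860.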