Update app.py
Browse files
app.py
CHANGED
@@ -20,53 +20,40 @@ def load_data():
|
|
20 |
print(f"Number of patent numbers: {len(patent_numbers)}")
|
21 |
print(f"Number of metadata entries: {len(metadata)}")
|
22 |
|
23 |
-
# Print sample metadata
|
24 |
-
sample_patent = next(iter(metadata))
|
25 |
-
print(f"Sample metadata for patent {sample_patent}:")
|
26 |
-
print(json.dumps(metadata[sample_patent], indent=2))
|
27 |
-
|
28 |
return embeddings, patent_numbers, metadata
|
29 |
|
30 |
embeddings, patent_numbers, metadata = load_data()
|
31 |
|
32 |
-
#
|
33 |
-
|
|
|
|
|
|
|
34 |
index.add(embeddings)
|
35 |
|
36 |
# Load BERT model for encoding search queries
|
37 |
-
|
38 |
-
print(f"Embedding dimension: {embedding_dim}")
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
model = SentenceTransformer('all-MiniLM-L6-v2') # Default to this model
|
47 |
|
48 |
def search(query, top_k=5):
|
|
|
|
|
49 |
# Encode the query
|
50 |
query_embedding = model.encode([query])[0]
|
|
|
51 |
|
52 |
-
|
53 |
-
print(f"Query embedding: {query_embedding}")
|
54 |
-
|
55 |
-
# Ensure the query embedding has the same dimension as the index
|
56 |
-
if query_embedding.shape[0] != index.d:
|
57 |
-
print(f"Query embedding dimension ({query_embedding.shape[0]}) does not match index dimension ({index.d})")
|
58 |
-
# Option 1: Pad or truncate the query embedding
|
59 |
-
if query_embedding.shape[0] < index.d:
|
60 |
-
query_embedding = np.pad(query_embedding, (0, index.d - query_embedding.shape[0]))
|
61 |
-
else:
|
62 |
-
query_embedding = query_embedding[:index.d]
|
63 |
|
64 |
# Perform similarity search
|
65 |
distances, indices = index.search(np.array([query_embedding]), top_k)
|
66 |
|
67 |
-
|
68 |
-
print(f"Distances: {distances}")
|
69 |
-
print(f"Indices: {indices}")
|
70 |
|
71 |
results = []
|
72 |
for i, idx in enumerate(indices[0]):
|
@@ -76,17 +63,23 @@ def search(query, top_k=5):
|
|
76 |
continue
|
77 |
patent_data = metadata[patent_number]
|
78 |
result = f"Patent Number: {patent_number}\n"
|
79 |
-
|
80 |
-
# Safely extract abstract
|
81 |
abstract = patent_data.get('abstract', 'No abstract available')
|
82 |
-
|
83 |
-
|
84 |
-
else:
|
85 |
-
result += f"Abstract: Unable to display (type: {type(abstract)})\n"
|
86 |
-
|
87 |
-
result += f"Similarity Score: {1 - distances[0][i]:.4f}\n\n"
|
88 |
results.append(result)
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
return "\n".join(results)
|
91 |
|
92 |
# Create Gradio interface
|
|
|
20 |
print(f"Number of patent numbers: {len(patent_numbers)}")
|
21 |
print(f"Number of metadata entries: {len(metadata)}")
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
return embeddings, patent_numbers, metadata
|
24 |
|
25 |
embeddings, patent_numbers, metadata = load_data()
|
26 |
|
27 |
# L2-normalize every embedding row so that inner-product search over the
# index is equivalent to cosine similarity.
row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / row_norms

# Build a flat (exact, brute-force) inner-product FAISS index sized to the
# embedding dimensionality, then load all normalized vectors into it.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Load BERT model for encoding search queries
model = SentenceTransformer('all-mpnet-base-v2')
|
|
|
36 |
|
37 |
+
def exact_text_match(query, metadata):
    """Return patents whose abstract or claims contain *query* as a literal,
    case-insensitive substring.

    Args:
        query: Search text to look for.
        metadata: Mapping of patent number -> patent data dict. Only the
            'abstract' and 'claims' fields are inspected; a missing field is
            treated as an empty string.

    Returns:
        List of ``(patent_number, 1.0)`` tuples, in *metadata* iteration
        order. The score is always 1.0 — every exact match ranks equally.
    """
    # Hoist the loop-invariant lowercasing: the original recomputed
    # query.lower() twice per patent.
    needle = query.lower()
    matches = []
    for patent_number, data in metadata.items():
        if (needle in data.get('abstract', '').lower()
                or needle in data.get('claims', '').lower()):
            matches.append((patent_number, 1.0))  # Score of 1.0 for exact match
    return matches
|
|
|
43 |
|
44 |
def search(query, top_k=5):
|
45 |
+
print(f"Searching for: {query}")
|
46 |
+
|
47 |
# Encode the query
|
48 |
query_embedding = model.encode([query])[0]
|
49 |
+
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
50 |
|
51 |
+
print(f"Query embedding shape: {query_embedding.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# Perform similarity search
|
54 |
distances, indices = index.search(np.array([query_embedding]), top_k)
|
55 |
|
56 |
+
print(f"FAISS search results - Distances: {distances}, Indices: {indices}")
|
|
|
|
|
57 |
|
58 |
results = []
|
59 |
for i, idx in enumerate(indices[0]):
|
|
|
63 |
continue
|
64 |
patent_data = metadata[patent_number]
|
65 |
result = f"Patent Number: {patent_number}\n"
|
|
|
|
|
66 |
abstract = patent_data.get('abstract', 'No abstract available')
|
67 |
+
result += f"Abstract: {abstract[:200]}...\n"
|
68 |
+
result += f"Similarity Score: {distances[0][i]:.4f}\n\n"
|
|
|
|
|
|
|
|
|
69 |
results.append(result)
|
70 |
|
71 |
+
# Fallback to exact text match if no results or low similarity
|
72 |
+
if not results or distances[0][0] < 0.5:
|
73 |
+
print("Falling back to exact text match")
|
74 |
+
exact_matches = exact_text_match(query, metadata)
|
75 |
+
for patent_number, score in exact_matches[:top_k]:
|
76 |
+
patent_data = metadata[patent_number]
|
77 |
+
result = f"Patent Number: {patent_number}\n"
|
78 |
+
abstract = patent_data.get('abstract', 'No abstract available')
|
79 |
+
result += f"Abstract: {abstract[:200]}...\n"
|
80 |
+
result += f"Exact Match Score: {score:.4f}\n\n"
|
81 |
+
results.append(result)
|
82 |
+
|
83 |
return "\n".join(results)
|
84 |
|
85 |
# Create Gradio interface
|