Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ import faiss
|
|
5 |
import json
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
|
8 |
-
# Load embeddings and metadata
|
9 |
def load_data():
|
10 |
with h5py.File('patent_embeddings.h5', 'r') as f:
|
11 |
embeddings = f['embeddings'][:]
|
@@ -17,6 +16,7 @@ def load_data():
|
|
17 |
data = json.loads(line)
|
18 |
metadata[data['patent_number']] = data
|
19 |
|
|
|
20 |
return embeddings, patent_numbers, metadata
|
21 |
|
22 |
embeddings, patent_numbers, metadata = load_data()
|
@@ -26,12 +26,30 @@ index = faiss.IndexFlatL2(embeddings.shape[1])
|
|
26 |
index.add(embeddings)
|
27 |
|
28 |
# Load BERT model for encoding search queries
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def search(query, top_k=5):
|
32 |
# Encode the query
|
33 |
query_embedding = model.encode([query])[0]
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
# Perform similarity search
|
36 |
distances, indices = index.search(np.array([query_embedding]), top_k)
|
37 |
|
@@ -55,4 +73,5 @@ iface = gr.Interface(
|
|
55 |
description="Enter a query to find similar patents based on their embeddings."
|
56 |
)
|
57 |
|
58 |
-
|
|
|
|
5 |
import json
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
|
|
|
8 |
def load_data():
|
9 |
with h5py.File('patent_embeddings.h5', 'r') as f:
|
10 |
embeddings = f['embeddings'][:]
|
|
|
16 |
data = json.loads(line)
|
17 |
metadata[data['patent_number']] = data
|
18 |
|
19 |
+
print(f"Embedding shape: {embeddings.shape}")
|
20 |
return embeddings, patent_numbers, metadata
|
21 |
|
22 |
embeddings, patent_numbers, metadata = load_data()
|
|
|
26 |
index.add(embeddings)
|
27 |
|
28 |
# Load BERT model for encoding search queries
|
29 |
+
embedding_dim = embeddings.shape[1]
|
30 |
+
print(f"Embedding dimension: {embedding_dim}")
|
31 |
+
|
32 |
+
if embedding_dim == 384:
|
33 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
34 |
+
elif embedding_dim == 768:
|
35 |
+
model = SentenceTransformer('all-mpnet-base-v2')
|
36 |
+
else:
|
37 |
+
print(f"Unexpected embedding dimension: {embedding_dim}")
|
38 |
+
model = SentenceTransformer('all-MiniLM-L6-v2') # Default to this model
|
39 |
|
40 |
def search(query, top_k=5):
|
41 |
# Encode the query
|
42 |
query_embedding = model.encode([query])[0]
|
43 |
|
44 |
+
# Ensure the query embedding has the same dimension as the index
|
45 |
+
if query_embedding.shape[0] != index.d:
|
46 |
+
print(f"Query embedding dimension ({query_embedding.shape[0]}) does not match index dimension ({index.d})")
|
47 |
+
# Option 1: Pad or truncate the query embedding
|
48 |
+
if query_embedding.shape[0] < index.d:
|
49 |
+
query_embedding = np.pad(query_embedding, (0, index.d - query_embedding.shape[0]))
|
50 |
+
else:
|
51 |
+
query_embedding = query_embedding[:index.d]
|
52 |
+
|
53 |
# Perform similarity search
|
54 |
distances, indices = index.search(np.array([query_embedding]), top_k)
|
55 |
|
|
|
73 |
description="Enter a query to find similar patents based on their embeddings."
|
74 |
)
|
75 |
|
76 |
+
if __name__ == "__main__":
|
77 |
+
iface.launch(share=True)
|