bhlewis committed on
Commit
5c704be
1 Parent(s): 3e02536

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -15
app.py CHANGED
@@ -4,6 +4,8 @@ import h5py
4
  import faiss
5
  import json
6
  from sentence_transformers import SentenceTransformer
 
 
7
 
8
  def load_data():
9
  try:
@@ -12,16 +14,18 @@ def load_data():
12
  patent_numbers = f['patent_numbers'][:]
13
 
14
  metadata = {}
 
15
  with open('patent_metadata.jsonl', 'r') as f:
16
  for line in f:
17
  data = json.loads(line)
18
  metadata[data['patent_number']] = data
 
19
 
20
  print(f"Embedding shape: {embeddings.shape}")
21
  print(f"Number of patent numbers: {len(patent_numbers)}")
22
  print(f"Number of metadata entries: {len(metadata)}")
23
 
24
- return embeddings, patent_numbers, metadata
25
  except FileNotFoundError as e:
26
  print(f"Error: Could not find file. {e}")
27
  raise
@@ -29,7 +33,7 @@ def load_data():
29
  print(f"An unexpected error occurred while loading data: {e}")
30
  raise
31
 
32
- embeddings, patent_numbers, metadata = load_data()
33
 
34
  # Normalize embeddings for cosine similarity
35
  embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
@@ -41,23 +45,43 @@ index.add(embeddings)
41
  # Load BERT model for encoding search queries
42
  model = SentenceTransformer('all-mpnet-base-v2')
43
 
44
- def search(query, top_k=5):
 
 
 
 
45
  print(f"Searching for: {query}")
46
 
47
- # Encode the query
48
  query_embedding = model.encode([query])[0]
49
  query_embedding = query_embedding / np.linalg.norm(query_embedding)
50
 
51
- print(f"Query embedding shape: {query_embedding.shape}")
 
52
 
53
- # Perform similarity search
54
- distances, indices = index.search(np.array([query_embedding]), top_k)
 
 
55
 
56
- print(f"FAISS search results - Distances: {distances}, Indices: {indices}")
 
 
 
 
57
 
58
- results = []
59
- for i, idx in enumerate(indices[0]):
60
  patent_number = patent_numbers[idx].decode('utf-8')
 
 
 
 
 
 
 
 
 
 
61
  if patent_number not in metadata:
62
  print(f"Warning: Patent number {patent_number} not found in metadata")
63
  continue
@@ -65,20 +89,19 @@ def search(query, top_k=5):
65
  result = f"Patent Number: {patent_number}\n"
66
  text = patent_data.get('text', 'No text available')
67
  result += f"Text: {text[:200]}...\n"
68
- result += f"Similarity Score: {distances[0][i]:.4f}\n\n"
69
  results.append(result)
70
 
71
- return "\n".join(results[:top_k])
72
 
73
  # Create Gradio interface
74
  iface = gr.Interface(
75
- fn=search,
76
  inputs=gr.Textbox(lines=2, placeholder="Enter your search query here..."),
77
  outputs=gr.Textbox(lines=10, label="Search Results"),
78
  title="Patent Similarity Search",
79
- description="Enter a query to find similar patents based on their embeddings."
80
  )
81
 
82
  if __name__ == "__main__":
83
  iface.launch()
84
-
 
4
  import faiss
5
  import json
6
  from sentence_transformers import SentenceTransformer
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
 
10
  def load_data():
11
  try:
 
14
  patent_numbers = f['patent_numbers'][:]
15
 
16
  metadata = {}
17
+ texts = []
18
  with open('patent_metadata.jsonl', 'r') as f:
19
  for line in f:
20
  data = json.loads(line)
21
  metadata[data['patent_number']] = data
22
+ texts.append(data['text'])
23
 
24
  print(f"Embedding shape: {embeddings.shape}")
25
  print(f"Number of patent numbers: {len(patent_numbers)}")
26
  print(f"Number of metadata entries: {len(metadata)}")
27
 
28
+ return embeddings, patent_numbers, metadata, texts
29
  except FileNotFoundError as e:
30
  print(f"Error: Could not find file. {e}")
31
  raise
 
33
  print(f"An unexpected error occurred while loading data: {e}")
34
  raise
35
 
36
+ embeddings, patent_numbers, metadata, texts = load_data()
37
 
38
  # Normalize embeddings for cosine similarity
39
  embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
 
45
  # Load the MPNet-based sentence-transformers model for encoding search queries
46
  model = SentenceTransformer('all-mpnet-base-v2')
47
 
48
+ # Create TF-IDF vectorizer
49
+ tfidf_vectorizer = TfidfVectorizer(stop_words='english')
50
+ tfidf_matrix = tfidf_vectorizer.fit_transform(texts)
51
+
52
+ def hybrid_search(query, top_k=5):
53
  print(f"Searching for: {query}")
54
 
55
+ # Encode the query using the transformer model
56
  query_embedding = model.encode([query])[0]
57
  query_embedding = query_embedding / np.linalg.norm(query_embedding)
58
 
59
+ # Perform semantic similarity search
60
+ semantic_distances, semantic_indices = index.search(np.array([query_embedding]), top_k * 2)
61
 
62
+ # Perform TF-IDF based search
63
+ query_tfidf = tfidf_vectorizer.transform([query])
64
+ tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
65
+ tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
66
 
67
+ # Combine and rank results
68
+ combined_results = {}
69
+ for i, idx in enumerate(semantic_indices[0]):
70
+ patent_number = patent_numbers[idx].decode('utf-8')
71
+ combined_results[patent_number] = semantic_distances[0][i]
72
 
73
+ for idx in tfidf_indices:
 
74
  patent_number = patent_numbers[idx].decode('utf-8')
75
+ if patent_number in combined_results:
76
+ combined_results[patent_number] += tfidf_similarities[idx]
77
+ else:
78
+ combined_results[patent_number] = tfidf_similarities[idx]
79
+
80
+ # Sort and get top results
81
+ top_results = sorted(combined_results.items(), key=lambda x: x[1], reverse=True)[:top_k]
82
+
83
+ results = []
84
+ for patent_number, score in top_results:
85
  if patent_number not in metadata:
86
  print(f"Warning: Patent number {patent_number} not found in metadata")
87
  continue
 
89
  result = f"Patent Number: {patent_number}\n"
90
  text = patent_data.get('text', 'No text available')
91
  result += f"Text: {text[:200]}...\n"
92
+ result += f"Combined Score: {score:.4f}\n\n"
93
  results.append(result)
94
 
95
+ return "\n".join(results)
96
 
97
  # Create Gradio interface
98
  iface = gr.Interface(
99
+ fn=hybrid_search,
100
  inputs=gr.Textbox(lines=2, placeholder="Enter your search query here..."),
101
  outputs=gr.Textbox(lines=10, label="Search Results"),
102
  title="Patent Similarity Search",
103
+ description="Enter a query to find similar patents based on their content."
104
  )
105
 
106
  if __name__ == "__main__":
107
  iface.launch()