bhlewis committed on
Commit
3961929
·
verified ·
1 Parent(s): cf6f1f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -71
app.py CHANGED
@@ -49,9 +49,9 @@ embeddings, patent_numbers, metadata, texts = load_data()
49
 
50
  # Load BERT model for encoding search queries
51
  try:
52
- bert_model = AutoModel.from_pretrained('anferico/bert-for-patents')
53
  tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
54
- word_embedding_model = models.Transformer(bert_model, tokenizer)
 
55
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
56
  model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
57
  except Exception as e:
@@ -88,72 +88,4 @@ def extract_key_features(text):
88
 
89
  def compare_features(query_features, patent_features):
90
  common_features = set(query_features) & set(patent_features)
91
- similarity_score = len(common_features) / max(len(query_features), len(patent_features))
92
- return common_features, similarity_score
93
-
94
- def hybrid_search(query, top_k=5):
95
- print(f"Original query: {query}")
96
-
97
- query_features = extract_key_features(query)
98
-
99
- # Encode the query using the transformer model
100
- query_embedding = model.encode([query])[0]
101
- query_embedding = query_embedding / np.linalg.norm(query_embedding)
102
-
103
- # Perform semantic similarity search
104
- semantic_distances, semantic_indices = index.search(np.array([query_embedding]).astype('float32'), top_k * 2)
105
-
106
- # Perform TF-IDF based search
107
- query_tfidf = tfidf_vectorizer.transform([query])
108
- tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
109
- tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]
110
-
111
- # Combine and rank results
112
- combined_results = {}
113
- for i, idx in enumerate(semantic_indices[0]):
114
- patent_number = patent_numbers[idx].decode('utf-8')
115
- text = metadata[patent_number]['text']
116
- patent_features = extract_key_features(text)
117
- common_features, feature_similarity = compare_features(query_features, patent_features)
118
- combined_results[patent_number] = {
119
- 'score': semantic_distances[0][i] * 1.5 + feature_similarity,
120
- 'common_features': common_features,
121
- 'text': text
122
- }
123
-
124
- for idx in tfidf_indices:
125
- patent_number = patent_numbers[idx].decode('utf-8')
126
- if patent_number not in combined_results:
127
- text = metadata[patent_number]['text']
128
- patent_features = extract_key_features(text)
129
- common_features, feature_similarity = compare_features(query_features, patent_features)
130
- combined_results[patent_number] = {
131
- 'score': tfidf_similarities[idx] + feature_similarity,
132
- 'common_features': common_features,
133
- 'text': text
134
- }
135
-
136
- # Sort and get top results
137
- top_results = sorted(combined_results.items(), key=lambda x: x[1]['score'], reverse=True)[:top_k]
138
-
139
- results = []
140
- for patent_number, data in top_results:
141
- result = f"Patent Number: {patent_number}\n"
142
- result += f"Text: {data['text'][:200]}...\n"
143
- result += f"Combined Score: {data['score']:.4f}\n"
144
- result += f"Common Key Features: {', '.join(data['common_features'])}\n\n"
145
- results.append(result)
146
-
147
- return "\n".join(results)
148
-
149
- # Create Gradio interface
150
- iface = gr.Interface(
151
- fn=hybrid_search,
152
- inputs=gr.Textbox(lines=2, placeholder="Enter your patent query here..."),
153
- outputs=gr.Textbox(lines=10, label="Search Results"),
154
- title="Patent Similarity Search",
155
- description="Enter a patent description to find similar patents based on key features."
156
- )
157
-
158
- if __name__ == "__main__":
159
- iface.launch()
 
49
 
50
  # Load BERT model for encoding search queries
51
  try:
 
52
  tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
53
+ bert_model = AutoModel.from_pretrained('anferico/bert-for-patents')
54
+ word_embedding_model = models.Transformer(model=bert_model, tokenizer=tokenizer)
55
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
56
  model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
57
  except Exception as e:
 
88
 
89
  def compare_features(query_features, patent_features):
90
  common_features = set(query_features) & set(patent_features)
91
+ similarity_score = len(common_features) / max(len(query_features), len