bhlewis committed on
Commit
4a2057c
1 Parent(s): 086c46f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -38
app.py CHANGED
@@ -20,53 +20,40 @@ def load_data():
20
  print(f"Number of patent numbers: {len(patent_numbers)}")
21
  print(f"Number of metadata entries: {len(metadata)}")
22
 
23
- # Print sample metadata
24
- sample_patent = next(iter(metadata))
25
- print(f"Sample metadata for patent {sample_patent}:")
26
- print(json.dumps(metadata[sample_patent], indent=2))
27
-
28
  return embeddings, patent_numbers, metadata
29
 
30
  embeddings, patent_numbers, metadata = load_data()
31
 
32
- # Create FAISS index
33
- index = faiss.IndexFlatL2(embeddings.shape[1])
 
 
 
34
  index.add(embeddings)
35
 
36
  # Load BERT model for encoding search queries
37
- embedding_dim = embeddings.shape[1]
38
- print(f"Embedding dimension: {embedding_dim}")
39
 
40
- if embedding_dim == 384:
41
- model = SentenceTransformer('all-MiniLM-L6-v2')
42
- elif embedding_dim == 768:
43
- model = SentenceTransformer('all-mpnet-base-v2')
44
- else:
45
- print(f"Unexpected embedding dimension: {embedding_dim}")
46
- model = SentenceTransformer('all-MiniLM-L6-v2') # Default to this model
47
 
48
  def search(query, top_k=5):
 
 
49
  # Encode the query
50
  query_embedding = model.encode([query])[0]
 
51
 
52
- # Debug: Print query embedding
53
- print(f"Query embedding: {query_embedding}")
54
-
55
- # Ensure the query embedding has the same dimension as the index
56
- if query_embedding.shape[0] != index.d:
57
- print(f"Query embedding dimension ({query_embedding.shape[0]}) does not match index dimension ({index.d})")
58
- # Option 1: Pad or truncate the query embedding
59
- if query_embedding.shape[0] < index.d:
60
- query_embedding = np.pad(query_embedding, (0, index.d - query_embedding.shape[0]))
61
- else:
62
- query_embedding = query_embedding[:index.d]
63
 
64
  # Perform similarity search
65
  distances, indices = index.search(np.array([query_embedding]), top_k)
66
 
67
- # Debug: Print distances and indices
68
- print(f"Distances: {distances}")
69
- print(f"Indices: {indices}")
70
 
71
  results = []
72
  for i, idx in enumerate(indices[0]):
@@ -76,17 +63,23 @@ def search(query, top_k=5):
76
  continue
77
  patent_data = metadata[patent_number]
78
  result = f"Patent Number: {patent_number}\n"
79
-
80
- # Safely extract abstract
81
  abstract = patent_data.get('abstract', 'No abstract available')
82
- if isinstance(abstract, str):
83
- result += f"Abstract: {abstract[:200]}...\n"
84
- else:
85
- result += f"Abstract: Unable to display (type: {type(abstract)})\n"
86
-
87
- result += f"Similarity Score: {1 - distances[0][i]:.4f}\n\n"
88
  results.append(result)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  return "\n".join(results)
91
 
92
  # Create Gradio interface
 
20
  print(f"Number of patent numbers: {len(patent_numbers)}")
21
  print(f"Number of metadata entries: {len(metadata)}")
22
 
 
 
 
 
 
23
  return embeddings, patent_numbers, metadata
24
 
25
  embeddings, patent_numbers, metadata = load_data()
26
 
27
+ # Normalize embeddings for cosine similarity
28
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
29
+
30
+ # Create FAISS index for cosine similarity
31
+ index = faiss.IndexFlatIP(embeddings.shape[1])
32
  index.add(embeddings)
33
 
34
  # Load BERT model for encoding search queries
35
+ model = SentenceTransformer('all-mpnet-base-v2')
 
36
 
37
+ def exact_text_match(query, metadata):
38
+ matches = []
39
+ for patent_number, data in metadata.items():
40
+ if query.lower() in data.get('abstract', '').lower() or query.lower() in data.get('claims', '').lower():
41
+ matches.append((patent_number, 1.0)) # Score of 1.0 for exact match
42
+ return matches
 
43
 
44
  def search(query, top_k=5):
45
+ print(f"Searching for: {query}")
46
+
47
  # Encode the query
48
  query_embedding = model.encode([query])[0]
49
+ query_embedding = query_embedding / np.linalg.norm(query_embedding)
50
 
51
+ print(f"Query embedding shape: {query_embedding.shape}")
 
 
 
 
 
 
 
 
 
 
52
 
53
  # Perform similarity search
54
  distances, indices = index.search(np.array([query_embedding]), top_k)
55
 
56
+ print(f"FAISS search results - Distances: {distances}, Indices: {indices}")
 
 
57
 
58
  results = []
59
  for i, idx in enumerate(indices[0]):
 
63
  continue
64
  patent_data = metadata[patent_number]
65
  result = f"Patent Number: {patent_number}\n"
 
 
66
  abstract = patent_data.get('abstract', 'No abstract available')
67
+ result += f"Abstract: {abstract[:200]}...\n"
68
+ result += f"Similarity Score: {distances[0][i]:.4f}\n\n"
 
 
 
 
69
  results.append(result)
70
 
71
+ # Fallback to exact text match if no results or low similarity
72
+ if not results or distances[0][0] < 0.5:
73
+ print("Falling back to exact text match")
74
+ exact_matches = exact_text_match(query, metadata)
75
+ for patent_number, score in exact_matches[:top_k]:
76
+ patent_data = metadata[patent_number]
77
+ result = f"Patent Number: {patent_number}\n"
78
+ abstract = patent_data.get('abstract', 'No abstract available')
79
+ result += f"Abstract: {abstract[:200]}...\n"
80
+ result += f"Exact Match Score: {score:.4f}\n\n"
81
+ results.append(result)
82
+
83
  return "\n".join(results)
84
 
85
  # Create Gradio interface