File size: 3,899 Bytes
663d1ad
 
 
 
 
 
5c704be
 
663d1ad
 
6744e1a
 
 
 
 
 
5c704be
6744e1a
 
 
 
5c704be
6744e1a
 
 
 
 
5c704be
6744e1a
 
 
 
 
 
663d1ad
5c704be
663d1ad
4a2057c
 
 
 
 
663d1ad
 
 
4a2057c
74523b8
5c704be
 
 
 
 
4a2057c
 
5c704be
663d1ad
4a2057c
663d1ad
5c704be
 
74523b8
5c704be
 
 
 
663d1ad
5c704be
 
 
 
 
086c46f
5c704be
13d437f
5c704be
 
 
 
 
 
 
 
 
 
c498c82
 
 
663d1ad
 
13d437f
 
5c704be
663d1ad
 
5c704be
663d1ad
 
 
5c704be
663d1ad
 
 
5c704be
663d1ad
 
74523b8
13d437f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import numpy as np
import h5py
import faiss
import json
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def load_data():
    """Load patent embeddings and metadata from local files.

    Reads ``patent_embeddings.h5`` (HDF5 datasets ``embeddings`` and
    ``patent_numbers``) and ``patent_metadata.jsonl`` (one JSON object per
    line, each with at least ``patent_number`` and ``text`` keys).

    Returns:
        tuple: ``(embeddings, patent_numbers, metadata, texts)`` where
        ``embeddings`` is a 2-D array, ``patent_numbers`` the parallel ID
        array, ``metadata`` a dict keyed by patent number, and ``texts`` the
        patent texts in file order (the order the TF-IDF matrix is built in).

    Raises:
        FileNotFoundError: if either input file is missing.
        Exception: any other load/parse failure is logged and re-raised.
    """
    try:
        with h5py.File('patent_embeddings.h5', 'r') as f:
            embeddings = f['embeddings'][:]
            patent_numbers = f['patent_numbers'][:]

        metadata = {}
        texts = []
        # Explicit UTF-8: the platform default encoding may differ and
        # silently mis-decode the JSONL on some systems.
        with open('patent_metadata.jsonl', 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline) which
                    # would otherwise make json.loads('') raise.
                    continue
                data = json.loads(line)
                metadata[data['patent_number']] = data
                texts.append(data['text'])

        print(f"Embedding shape: {embeddings.shape}")
        print(f"Number of patent numbers: {len(patent_numbers)}")
        print(f"Number of metadata entries: {len(metadata)}")

        return embeddings, patent_numbers, metadata, texts
    except FileNotFoundError as e:
        print(f"Error: Could not find file. {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred while loading data: {e}")
        raise

embeddings, patent_numbers, metadata, texts = load_data()

# Normalize embeddings for cosine similarity
# L2-normalize each row so that inner product == cosine similarity.
# NOTE(review): assumes no zero-norm rows — a zero vector would produce
# NaNs here; confirm upstream embeddings are never all-zero.
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# Create FAISS index for cosine similarity (IndexFlatIP = exact inner
# product, which on normalized vectors is cosine similarity).
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Load the sentence-transformer model used to encode search queries.
# Must match the model that produced the stored embeddings.
model = SentenceTransformer('all-mpnet-base-v2')

# Build the TF-IDF matrix over the patent texts; row order matches
# `texts` (and therefore the embedding/patent_numbers order from load_data).
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = tfidf_vectorizer.fit_transform(texts)

def hybrid_search(query, top_k=5):
    """Search patents by combining semantic (FAISS) and lexical (TF-IDF) scores.

    Both retrievers over-fetch ``top_k * 2`` candidates; a patent found by
    both gets the sum of its two scores, so agreement ranks higher.

    Args:
        query: Free-text search query.
        top_k: Number of results to return (default 5).

    Returns:
        str: Human-readable result blocks (patent number, 200-char text
        preview, combined score), joined by newlines.
    """
    print(f"Searching for: {query}")

    # Encode and L2-normalize the query so FAISS inner product == cosine.
    query_embedding = model.encode([query])[0]
    query_embedding = query_embedding / np.linalg.norm(query_embedding)

    # Semantic retrieval (over-fetch 2x so the merged set can fill top_k).
    semantic_distances, semantic_indices = index.search(
        np.array([query_embedding]), top_k * 2
    )

    # Lexical retrieval via TF-IDF cosine similarity.
    query_tfidf = tfidf_vectorizer.transform([query])
    tfidf_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
    tfidf_indices = tfidf_similarities.argsort()[-top_k * 2:][::-1]

    # Merge: patent number -> summed score across both retrievers.
    combined_results = {}
    for i, idx in enumerate(semantic_indices[0]):
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than top_k*2
            # vectors; indexing with -1 would silently grab the LAST patent.
            continue
        patent_number = patent_numbers[idx].decode('utf-8')
        combined_results[patent_number] = semantic_distances[0][i]

    for idx in tfidf_indices:
        patent_number = patent_numbers[idx].decode('utf-8')
        combined_results[patent_number] = (
            combined_results.get(patent_number, 0.0) + tfidf_similarities[idx]
        )

    # Rank by combined score and keep the best top_k.
    top_results = sorted(
        combined_results.items(), key=lambda x: x[1], reverse=True
    )[:top_k]

    results = []
    for patent_number, score in top_results:
        if patent_number not in metadata:
            print(f"Warning: Patent number {patent_number} not found in metadata")
            continue
        patent_data = metadata[patent_number]
        text = patent_data.get('text', 'No text available')
        result = f"Patent Number: {patent_number}\n"
        result += f"Text: {text[:200]}...\n"
        result += f"Combined Score: {score:.4f}\n\n"
        results.append(result)

    return "\n".join(results)

# Create Gradio interface: a single text query in, plain-text ranked
# results out, wired to hybrid_search (top_k stays at its default of 5).
iface = gr.Interface(
    fn=hybrid_search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your search query here..."),
    outputs=gr.Textbox(lines=10, label="Search Results"),
    title="Patent Similarity Search",
    description="Enter a query to find similar patents based on their content."
)

# Launch the web UI only when run as a script (not when imported,
# e.g. by Hugging Face Spaces tooling that imports the module).
if __name__ == "__main__":
    iface.launch()