MikeTerekhov commited on
Commit
6cb0a90
·
verified ·
1 Parent(s): 03182a1

Create rag_metadata.py

Browse files
Files changed (1) hide show
  1. rag_metadata.py +51 -0
rag_metadata.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from torch.nn.functional import cosine_similarity
3
+ import torch
4
+
5
+ class SQLMetadataRetriever:
6
+ def __init__(self):
7
+ self.model = SentenceTransformer("all-MiniLM-L6-v2")
8
+ self.docs = []
9
+ self.embeddings = None
10
+
11
+ def add_documents(self, docs):
12
+ """Store and embed schema documents"""
13
+ self.docs = docs
14
+ self.embeddings = self.model.encode(docs, convert_to_tensor=True)
15
+
16
+ def retrieve(self, query, top_k=1):
17
+ query_embedding = self.model.encode(query, convert_to_tensor=True)
18
+
19
+ if self.embeddings is None or self.embeddings.shape[0] == 0:
20
+ raise ValueError("No embeddings found. Did you call add_documents()?")
21
+
22
+ available_docs = self.embeddings.shape[0]
23
+ top_k = min(top_k, available_docs)
24
+
25
+ # Explicitly expand the query embedding to match the number of documents
26
+ query_expanded = query_embedding.unsqueeze(0).expand(self.embeddings.size(0), -1)
27
+ scores = cosine_similarity(query_expanded, self.embeddings, dim=1)
28
+
29
+ # Now scores should be a 1D tensor with length equal to available_docs
30
+ top_indices = torch.topk(scores, top_k).indices.tolist()
31
+ return [self.docs[i] for i in top_indices]
32
+
33
+
34
+ # Example usage:
35
+ if __name__ == "__main__":
36
+ retriever = SQLMetadataRetriever()
37
+
38
+ metadata_docs = [
39
+ # Table: team
40
+ "Table team: columns are id (Unique team identifier), full_name (Full team name, e.g., 'Los Angeles Lakers'), abbreviation (3-letter team code, e.g., 'LAL'), city, state, year_founded.",
41
+
42
+ # Table: game
43
+ "Table game: columns are game_date (Date of the game), team_id_home, team_id_away (Unique IDs of home and away teams), team_name_home, team_name_away (Full names of the teams), pts_home, pts_away (Points scored), wl_home (W/L result), reb_home, reb_away (Total rebounds), ast_home, ast_away (Total assists), fgm_home, fg_pct_home (Field goals), fg3m_home (Three-pointers), ftm_home (Free throws), tov_home (Turnovers), and other game-related statistics."
44
+ ]
45
+
46
+
47
+ retriever.add_documents(metadata_docs)
48
+
49
+ question = "What is the most assists by the Celtics in a home game?"
50
+ relevant = retriever.retrieve(question, top_k=1)
51
+ print("Top match:", relevant[0])