# File size: 2,303 bytes
# 6cb0a90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import torch

class SQLMetadataRetriever:
    """Rank schema-description documents by semantic similarity to a query.

    Documents and queries are embedded with a SentenceTransformer model and
    compared with cosine similarity.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the embedding model.

        Args:
            model_name: Name or path passed to ``SentenceTransformer``.
                Defaults to the model this class always used before.
        """
        self.model = SentenceTransformer(model_name)
        # Documents, in the same order as the rows of self.embeddings.
        self.docs = []
        # Tensor of document embeddings; None until add_documents() is called.
        self.embeddings = None

    def add_documents(self, docs):
        """Store and embed schema documents, replacing any previous set.

        Args:
            docs: Iterable of document strings. A single string is treated
                as a collection containing one document.
        """
        if isinstance(docs, str):
            # encode() on a bare string returns one vector rather than a batch
            # of one, which would misalign embedding rows with self.docs.
            docs = [docs]
        # Copy so later mutation of the caller's list cannot desync docs
        # from their embeddings.
        self.docs = list(docs)
        if not self.docs:
            # Nothing to embed; retrieve() reports this state as an error.
            self.embeddings = None
            return
        self.embeddings = self.model.encode(self.docs, convert_to_tensor=True)

    def retrieve(self, query, top_k=1):
        """Return up to ``top_k`` stored documents most similar to ``query``.

        Args:
            query: Query text.
            top_k: Maximum number of documents to return; clamped to the
                number of stored documents.

        Returns:
            List of document strings ordered from most to least similar.

        Raises:
            ValueError: If no documents have been added, or ``top_k`` is
                negative.
        """
        # Validate before encoding so misuse does not pay for a model call.
        if self.embeddings is None or self.embeddings.shape[0] == 0:
            raise ValueError("No embeddings found. Did you call add_documents()?")
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        top_k = min(top_k, self.embeddings.shape[0])
        if top_k == 0:
            return []

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        # cosine_similarity broadcasts the (1, dim) query row against every
        # document row, producing one score per stored document.
        scores = cosine_similarity(
            query_embedding.unsqueeze(0), self.embeddings, dim=1
        )
        top_indices = torch.topk(scores, top_k).indices.tolist()
        return [self.docs[i] for i in top_indices]


# Demo: embed two table descriptions and ask which one fits a question.
if __name__ == "__main__":
    demo = SQLMetadataRetriever()

    # Implicit string concatenation keeps each column group on its own line;
    # the resulting strings are identical to single-line literals.
    team_doc = (
        "Table team: columns are id (Unique team identifier), "
        "full_name (Full team name, e.g., 'Los Angeles Lakers'), "
        "abbreviation (3-letter team code, e.g., 'LAL'), "
        "city, state, year_founded."
    )
    game_doc = (
        "Table game: columns are game_date (Date of the game), "
        "team_id_home, team_id_away (Unique IDs of home and away teams), "
        "team_name_home, team_name_away (Full names of the teams), "
        "pts_home, pts_away (Points scored), "
        "wl_home (W/L result), "
        "reb_home, reb_away (Total rebounds), "
        "ast_home, ast_away (Total assists), "
        "fgm_home, fg_pct_home (Field goals), "
        "fg3m_home (Three-pointers), "
        "ftm_home (Free throws), "
        "tov_home (Turnovers), "
        "and other game-related statistics."
    )
    demo.add_documents([team_doc, game_doc])

    # top_k=1 over two documents yields exactly one match.
    (best_match,) = demo.retrieve(
        "What is the most assists by the Celtics in a home game?", top_k=1
    )
    print(f"Top match: {best_match}")