|
from sentence_transformers import SentenceTransformer |
|
from torch.nn.functional import cosine_similarity |
|
import torch |
|
|
|
class SQLMetadataRetriever:
    """Retrieve the schema documents most similar to a natural-language query.

    Documents are embedded once with the "all-MiniLM-L6-v2"
    SentenceTransformer model and ranked against each query by cosine
    similarity.
    """

    def __init__(self):
        """Load the embedding model and start with an empty document store."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        # Stored documents; row i of ``embeddings`` belongs to ``docs[i]``.
        self.docs = []
        # Embeddings of ``docs``; ``None`` until documents have been added.
        self.embeddings = None

    def add_documents(self, docs) -> None:
        """Store and embed schema documents, replacing any previous ones.

        Args:
            docs: A single document string or an iterable of document
                strings.
        """
        # A bare string is one document, not an iterable of characters.
        if isinstance(docs, str):
            docs = [docs]

        # Snapshot the input: this makes generators usable and prevents
        # later mutation of the caller's list from desynchronising the
        # stored documents and their embeddings.
        self.docs = list(docs)

        if not self.docs:
            # Nothing to embed; ``retrieve`` reports the empty store.
            self.embeddings = None
            return

        self.embeddings = self.model.encode(self.docs, convert_to_tensor=True)

    def retrieve(self, query, top_k=1) -> list:
        """Return the stored documents most similar to ``query``.

        Args:
            query: Query text to embed and compare with the stored documents.
            top_k: Maximum number of documents to return. Values larger
                than the number of stored documents are capped.

        Returns:
            Up to ``top_k`` stored documents, ordered from most to least
            similar.

        Raises:
            ValueError: If no documents have been added, or if ``top_k``
                is negative.
        """
        # Validate before encoding so bad calls fail fast without running
        # the embedding model.
        if self.embeddings is None or self.embeddings.shape[0] == 0:
            raise ValueError("No embeddings found. Did you call add_documents()?")
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        # Never ask torch.topk for more entries than there are documents.
        top_k = min(top_k, self.embeddings.shape[0])

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # Repeat the query vector (as a view) once per document row so the
        # similarity is computed row-wise against every stored embedding.
        query_expanded = query_embedding.unsqueeze(0).expand_as(self.embeddings)
        scores = cosine_similarity(query_expanded, self.embeddings, dim=1)

        top_indices = torch.topk(scores, top_k).indices.tolist()
        return [self.docs[i] for i in top_indices]
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: index two table descriptions, then look up the one that best
    # matches a natural-language question.
    team_schema = "Table team: columns are id (Unique team identifier), full_name (Full team name, e.g., 'Los Angeles Lakers'), abbreviation (3-letter team code, e.g., 'LAL'), city, state, year_founded."
    game_schema = "Table game: columns are game_date (Date of the game), team_id_home, team_id_away (Unique IDs of home and away teams), team_name_home, team_name_away (Full names of the teams), pts_home, pts_away (Points scored), wl_home (W/L result), reb_home, reb_away (Total rebounds), ast_home, ast_away (Total assists), fgm_home, fg_pct_home (Field goals), fg3m_home (Three-pointers), ftm_home (Free throws), tov_home (Turnovers), and other game-related statistics."

    demo_retriever = SQLMetadataRetriever()
    demo_retriever.add_documents([team_schema, game_schema])

    user_question = "What is the most assists by the Celtics in a home game?"
    hits = demo_retriever.retrieve(user_question, top_k=1)

    print("Top match:", hits[0])
|
|