File size: 2,303 Bytes
6cb0a90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import torch
class SQLMetadataRetriever:
    """Rank stored schema documents by semantic similarity to a question.

    Documents are embedded once with a SentenceTransformer model. Each query is
    embedded with the same model and compared to every stored document by
    cosine similarity.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the embedding model.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
                The default matches the model this class has always used.
        """
        self.model = SentenceTransformer(model_name)
        # Document texts; position i corresponds to row i of self.embeddings.
        self.docs = []
        # Embedding tensor produced by add_documents(); None until then.
        self.embeddings = None

    def add_documents(self, docs):
        """Store and embed schema documents, replacing any previously added set.

        Args:
            docs: Iterable of document strings. It is copied into a new list so
                that later mutation of the caller's collection, or a one-shot
                iterator being consumed by the encoder, cannot desynchronise
                ``self.docs`` from ``self.embeddings``.
        """
        docs = list(docs)
        # Encode before assigning so a failing encode leaves the old state intact.
        self.embeddings = self.model.encode(docs, convert_to_tensor=True)
        self.docs = docs

    def retrieve(self, query, top_k=1):
        """Return up to ``top_k`` stored documents most similar to ``query``.

        Args:
            query: Natural-language question to match against the documents.
            top_k: Maximum number of documents to return. Values larger than the
                number of stored documents are clamped; ``0`` returns ``[]``.

        Returns:
            List of document strings, most similar first.

        Raises:
            ValueError: If no documents have been added, or ``top_k`` is negative.
        """
        # Validate before encoding so a misconfigured call fails fast without
        # paying for a model forward pass.
        if self.embeddings is None or self.embeddings.shape[0] == 0:
            raise ValueError("No embeddings found. Did you call add_documents()?")
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        top_k = min(top_k, self.embeddings.shape[0])

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        # The unsqueezed query, shape (1, dim), broadcasts against the
        # document matrix, giving a 1-D tensor with one score per document.
        scores = cosine_similarity(query_embedding.unsqueeze(0), self.embeddings, dim=1)
        top_indices = torch.topk(scores, top_k).indices.tolist()
        return [self.docs[i] for i in top_indices]
# Example usage: embed two table descriptions and pick the one that best
# matches a natural-language question.
if __name__ == "__main__":
    sql_retriever = SQLMetadataRetriever()

    # One plain-text description per table.
    team_table_doc = "Table team: columns are id (Unique team identifier), full_name (Full team name, e.g., 'Los Angeles Lakers'), abbreviation (3-letter team code, e.g., 'LAL'), city, state, year_founded."
    game_table_doc = "Table game: columns are game_date (Date of the game), team_id_home, team_id_away (Unique IDs of home and away teams), team_name_home, team_name_away (Full names of the teams), pts_home, pts_away (Points scored), wl_home (W/L result), reb_home, reb_away (Total rebounds), ast_home, ast_away (Total assists), fgm_home, fg_pct_home (Field goals), fg3m_home (Three-pointers), ftm_home (Free throws), tov_home (Turnovers), and other game-related statistics."
    sql_retriever.add_documents([team_table_doc, game_table_doc])

    user_question = "What is the most assists by the Celtics in a home game?"
    matches = sql_retriever.retrieve(user_question, top_k=1)
    best_match = matches[0]
    print("Top match:", best_match)
|