Update database.py

database.py CHANGED (+34 -22)
@@ -5,6 +5,8 @@ import os
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
 from datasets import Dataset, load_dataset
+from transformers import AutoTokenizer, AutoModel
+import torch
 
 # User-configurable variables
 DB_NAME = "python_programs" # ChromaDB collection name
@@ -37,7 +39,7 @@ def store_program(client, code, sequence, vectors, collection_name=DB_NAME):
     """Store a program in ChromaDB with its code, sequence, and vectors."""
     collection = create_collection(client, collection_name)
 
-    # Flatten vectors to ensure they are a list of numbers
+    # Flatten vectors to ensure they are a list of numbers (ChromaDB expects flat embeddings)
     flattened_vectors = [item for sublist in vectors for item in sublist]
 
     # Store program data (ID, code, sequence, vectors)
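For reference, the flattening in store_program turns the per-operation vectors (a list of lists) into the single flat list of numbers that ChromaDB accepts as an embedding. A minimal sketch with made-up 6-component operation vectors:

# Hypothetical per-operation vectors; the real ones come from the analysis code in this repo.
vectors = [
    [0.1, 0.2, 0.0, 0.5, 1.0, 0.0],  # operation 1
    [0.3, 0.0, 0.4, 0.2, 0.0, 1.0],  # operation 2
]

# Same comprehension as in store_program: concatenate the sublists end to end.
flattened_vectors = [item for sublist in vectors for item in sublist]
print(len(flattened_vectors))  # 12, i.e. one flat list rather than a list of lists

Note that the flattened length scales with the number of operations, so embeddings stored this way can differ in length from one program to another.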
@@ -77,10 +79,10 @@ def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semanti
     collection = create_collection(client, collection_name)
 
     if semantic_query:
-        # Semantic search using
+        # Semantic search using CodeBERT embeddings
         query_vector = generate_semantic_vector(semantic_query)
         results = collection.query(
-            ...
+            query_embeddings=[query_vector],
             n_results=top_k,
             include=["documents", "metadatas"]
         )
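If it helps to see the query path in isolation, here is a self-contained sketch of the kind of embedding query used above. The collection name, documents, and 6-dim embeddings are made up for illustration; the real code goes through create_collection and generate_semantic_vector.

import chromadb

client = chromadb.Client()  # in-memory client, enough for a demo
collection = client.get_or_create_collection("python_programs_demo")  # hypothetical name

# Add two documents with hand-made 6-dim embeddings.
collection.add(
    ids=["prog1", "prog2"],
    documents=["df.groupby('a').mean()", "df.sort_values('b')"],
    embeddings=[[0.1, 0.2, 0.3, 0.0, 0.0, 0.0], [0.0, 0.9, 0.1, 0.2, 0.0, 0.0]],
    metadatas=[{"sequence": "groupby,mean"}, {"sequence": "sort_values"}],
)

# Query by embedding, as query_programs does when semantic_query is set.
results = collection.query(
    query_embeddings=[[0.1, 0.2, 0.25, 0.0, 0.0, 0.0]],
    n_results=2,
    include=["documents", "metadatas"],
)
print(results["documents"])

One thing worth keeping in mind: ChromaDB expects every embedding in a collection to share one dimensionality, so the flattened program vectors written by store_program and the 6-dim query vector produced by generate_semantic_vector have to agree in length.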
@@ -99,8 +101,12 @@ def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semanti
         sequence = meta['sequence'].split(',')
         if not semantic_query or is_subsequence(operations, sequence): # Ensure sequence match for operations
             try:
+                # Reconstruct program vectors (flatten if needed)
                 doc_vectors = eval(doc['vectors']) if isinstance(doc['vectors'], str) else doc['vectors']
-                ...
+                if isinstance(doc_vectors, (list, np.ndarray)) and len(doc_vectors) == 6:
+                    program_vector = doc_vectors # Single flat vector
+                else:
+                    program_vector = np.mean([v for v in doc_vectors if isinstance(v, (list, np.ndarray))], axis=0).tolist()
             except:
                 program_vector = [0] * 6 # Fallback for malformed vectors
             similarity = cosine_similarity([query_vector], [program_vector])[0][0] if program_vector and query_vector else 0
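Two small notes on this block: sklearn's cosine_similarity expects 2-D inputs, which is why both vectors are wrapped in an extra list, and ast.literal_eval is a safer drop-in than eval for parsing a stringified vector, since it only accepts Python literals. A short sketch with made-up values:

import ast
from sklearn.metrics.pairwise import cosine_similarity

stored = "[0.1, 0.2, 0.3, 0.0, 0.0, 0.0]"  # vectors serialized as a string in the metadata
program_vector = ast.literal_eval(stored)  # parses literals only, unlike eval
query_vector = [0.1, 0.25, 0.3, 0.0, 0.0, 0.0]

# cosine_similarity takes 2-D arrays of shape (n_samples, n_features)
similarity = cosine_similarity([query_vector], [program_vector])[0][0]
print(round(similarity, 3))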
@@ -155,23 +161,29 @@ def generate_description_tokens(sequence, vectors):
         tokens.append(f"span:{vec[3]:.2f}")
     return tokens
 
-def generate_semantic_vector(description):
-    """Generate a semantic vector for a textual description
-    #
-    ...
+def generate_semantic_vector(description, use_gpu=False):
+    """Generate a semantic vector for a textual description using CodeBERT, with CPU/GPU option."""
+    # Load CodeBERT model and tokenizer
+    model_name = "microsoft/codebert-base"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
+    model = AutoModel.from_pretrained(model_name).to(device)
+
+    # Tokenize and encode the description
+    inputs = tokenizer(description, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Generate embeddings
+    with torch.no_grad():
+        outputs = model(**inputs)
+    # Use mean pooling of the last hidden states
+    vector = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().tolist()
+
+    # Truncate or pad to 6D to match our vectors
+    if len(vector) < 6:
+        vector.extend([0] * (6 - len(vector)))
+    elif len(vector) > 6:
+        vector = vector[:6]
     return vector
 
 def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=HF_TOKEN):
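For context, microsoft/codebert-base produces 768-dimensional hidden states, so the mean-pooled embedding is cut down to its first six components here to line up with the 6-dim program vectors. Usage of the new function is a single call per description; the description string below is made up:

# The first call downloads microsoft/codebert-base, so it can be slow on a CPU-only Space.
vec = generate_semantic_vector("group a dataframe by a column and take the mean", use_gpu=False)
print(len(vec))  # 6: truncated from the 768-dim mean-pooled CodeBERT embedding
print(vec[:3])   # first few components of the query vector

Since the tokenizer and model are reloaded on every call, caching them at module level would presumably cut query latency considerably.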
@@ -184,7 +196,7 @@ def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=HF_TOKEN):
     data = {
         "code": results["documents"],
         "sequence": [meta["sequence"] for meta in results["metadatas"]],
-        "vectors": ...
+        "vectors": results["embeddings"], # ChromaDB already flattens embeddings
         "description_tokens": [meta.get('description_tokens', '') for meta in results["metadatas"]]
     }
 
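One thing worth double-checking for this change: ChromaDB's collection.get() does not return embeddings unless they are requested explicitly, so results["embeddings"] will come back empty with a default get. A minimal sketch of the export path under that assumption; the collection comes from the surrounding code, and the repo id and token are placeholders:

from datasets import Dataset

# Embeddings must be requested explicitly; get() only returns documents and metadatas by default.
results = collection.get(include=["documents", "metadatas", "embeddings"])

data = {
    "code": results["documents"],
    "sequence": [meta["sequence"] for meta in results["metadatas"]],
    "vectors": results["embeddings"],
    "description_tokens": [meta.get("description_tokens", "") for meta in results["metadatas"]],
}

Dataset.from_dict(data).push_to_hub("your-username/python_programs", token="hf_...")  # placeholder repo id and token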