Spaces:
Running
Running
Update database.py
Browse files- database.py +82 -22
database.py
CHANGED
@@ -1,25 +1,36 @@
|
|
1 |
# database.py
|
2 |
-
import
|
3 |
-
import os
|
4 |
from parser import parse_python_code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
def
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
def populate_sample_db():
|
23 |
# Sample programs for testing
|
24 |
samples = [
|
25 |
"""
|
@@ -39,8 +50,57 @@ def populate_sample_db():
|
|
39 |
for code in samples:
|
40 |
parts, sequence = parse_python_code(code)
|
41 |
vectors = [part['vector'] for part in parts]
|
42 |
-
store_program(code, sequence, vectors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
if __name__ == '__main__':
|
45 |
-
|
46 |
-
populate_sample_db()
|
|
|
1 |
# database.py
|
2 |
+
import chromadb
|
|
|
3 |
from parser import parse_python_code
|
4 |
+
import os
|
5 |
+
|
6 |
+
def init_chromadb():
|
7 |
+
# Initialize ChromaDB client (in-memory for now, can persist to disk)
|
8 |
+
client = chromadb.Client()
|
9 |
+
return client
|
10 |
+
|
11 |
+
def create_collection(client, collection_name="python_programs"):
|
12 |
+
# Create or get a collection for Python programs
|
13 |
+
try:
|
14 |
+
collection = client.get_collection(name=collection_name)
|
15 |
+
except:
|
16 |
+
collection = client.create_collection(name=collection_name)
|
17 |
+
return collection
|
18 |
|
19 |
+
def store_program(client, code, sequence, vectors, collection_name="python_programs"):
|
20 |
+
# Create or get collection
|
21 |
+
collection = create_collection(client, collection_name)
|
22 |
+
|
23 |
+
# Store program data (ID, code, sequence, vectors)
|
24 |
+
program_id = str(hash(code)) # Use hash of code as ID for uniqueness
|
25 |
+
collection.add(
|
26 |
+
documents=[code],
|
27 |
+
metadatas=[{"sequence": ",".join(sequence)}],
|
28 |
+
ids=[program_id],
|
29 |
+
embeddings=[vectors] # Store vectors as embeddings
|
30 |
+
)
|
31 |
+
return program_id
|
32 |
+
|
33 |
+
def populate_sample_db(client):
|
|
|
|
|
34 |
# Sample programs for testing
|
35 |
samples = [
|
36 |
"""
|
|
|
50 |
for code in samples:
|
51 |
parts, sequence = parse_python_code(code)
|
52 |
vectors = [part['vector'] for part in parts]
|
53 |
+
store_program(client, code, sequence, vectors)
|
54 |
+
|
55 |
+
def query_programs(client, operations, collection_name="python_programs", top_k=5):
|
56 |
+
"""Query the database for programs matching the operations sequence."""
|
57 |
+
collection = create_collection(client, collection_name)
|
58 |
+
|
59 |
+
# Convert operations to a query vector (average of operation vectors)
|
60 |
+
query_vector = sum([create_vector(op, 0, (1, 1), 100, []) for op in operations], []) / len(operations) if operations else [0, 0, 0, 0, 0, 0]
|
61 |
+
|
62 |
+
# Perform similarity search
|
63 |
+
results = collection.query(
|
64 |
+
query_embeddings=[query_vector],
|
65 |
+
n_results=top_k,
|
66 |
+
include=["documents", "metadatas"]
|
67 |
+
)
|
68 |
+
|
69 |
+
# Process results
|
70 |
+
matching_programs = []
|
71 |
+
for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
|
72 |
+
sequence = meta['sequence'].split(',')
|
73 |
+
if is_subsequence(operations, sequence):
|
74 |
+
similarity = cosine_similarity([query_vector], [np.mean(eval(doc['vectors']), axis=0) if doc['vectors'] else [0, 0, 0, 0, 0, 0]])[0][0]
|
75 |
+
matching_programs.append({'id': meta['id'], 'code': doc, 'similarity': similarity})
|
76 |
+
|
77 |
+
return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)
|
78 |
+
|
79 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
80 |
+
import numpy as np
|
81 |
+
|
82 |
+
def create_vector(category, level, location, total_lines, parent_path):
|
83 |
+
"""Helper to create a vector for query (matches parser's create_vector)."""
|
84 |
+
category_map = {
|
85 |
+
'import': 1, 'function': 2, 'async_function': 3, 'class': 4,
|
86 |
+
'if': 5, 'while': 6, 'for': 7, 'try': 8, 'expression': 9, 'spacer': 10,
|
87 |
+
'other': 11, 'elif': 12, 'else': 13, 'except': 14, 'finally': 15, 'return': 16,
|
88 |
+
'assigned_variable': 17, 'input_variable': 18, 'returned_variable': 19
|
89 |
+
}
|
90 |
+
category_id = category_map.get(category, 0)
|
91 |
+
start_line, end_line = location
|
92 |
+
span = (end_line - start_line + 1) / total_lines
|
93 |
+
center_pos = ((start_line + end_line) / 2) / total_lines
|
94 |
+
parent_depth = len(parent_path)
|
95 |
+
parent_weight = sum(category_map.get(parent.split('[')[0].lower(), 0) * (1 / (i + 1))
|
96 |
+
for i, parent in enumerate(parent_path)) / max(1, len(category_map))
|
97 |
+
return [category_id, level, center_pos, span, parent_depth, parent_weight]
|
98 |
+
|
99 |
+
def is_subsequence(subseq, seq):
|
100 |
+
"""Check if subseq is a subsequence of seq."""
|
101 |
+
it = iter(seq)
|
102 |
+
return all(item in it for item in subseq)
|
103 |
|
104 |
if __name__ == '__main__':
|
105 |
+
client = init_chromadb()
|
106 |
+
populate_sample_db(client)
|