broadfield-dev commited on
Commit
dda378f
·
verified ·
1 Parent(s): 79e0cb5

Update database.py

Browse files
Files changed (1) hide show
  1. database.py +82 -22
database.py CHANGED
@@ -1,25 +1,36 @@
1
  # database.py
2
- import sqlite3
3
- import os
4
  from parser import parse_python_code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def init_db():
7
- conn = sqlite3.connect('python_programs.db')
8
- c = conn.cursor()
9
- c.execute('''CREATE TABLE IF NOT EXISTS programs
10
- (id INTEGER PRIMARY KEY, code TEXT, sequence TEXT, vectors TEXT)''')
11
- conn.commit()
12
- conn.close()
13
-
14
- def store_program(code, sequence, vectors):
15
- conn = sqlite3.connect('python_programs.db')
16
- c = conn.cursor()
17
- c.execute("INSERT INTO programs (code, sequence, vectors) VALUES (?, ?, ?)",
18
- (code, ','.join(sequence), str(vectors)))
19
- conn.commit()
20
- conn.close()
21
-
22
- def populate_sample_db():
23
  # Sample programs for testing
24
  samples = [
25
  """
@@ -39,8 +50,57 @@ def populate_sample_db():
39
  for code in samples:
40
  parts, sequence = parse_python_code(code)
41
  vectors = [part['vector'] for part in parts]
42
- store_program(code, sequence, vectors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  if __name__ == '__main__':
45
- init_db()
46
- populate_sample_db()
 
1
  # database.py
2
+ import chromadb
 
3
  from parser import parse_python_code
4
+ import os
5
+
6
+ def init_chromadb():
7
+ # Initialize ChromaDB client (in-memory for now, can persist to disk)
8
+ client = chromadb.Client()
9
+ return client
10
+
11
+ def create_collection(client, collection_name="python_programs"):
12
+ # Create or get a collection for Python programs
13
+ try:
14
+ collection = client.get_collection(name=collection_name)
15
+ except:
16
+ collection = client.create_collection(name=collection_name)
17
+ return collection
18
 
19
+ def store_program(client, code, sequence, vectors, collection_name="python_programs"):
20
+ # Create or get collection
21
+ collection = create_collection(client, collection_name)
22
+
23
+ # Store program data (ID, code, sequence, vectors)
24
+ program_id = str(hash(code)) # Use hash of code as ID for uniqueness
25
+ collection.add(
26
+ documents=[code],
27
+ metadatas=[{"sequence": ",".join(sequence)}],
28
+ ids=[program_id],
29
+ embeddings=[vectors] # Store vectors as embeddings
30
+ )
31
+ return program_id
32
+
33
+ def populate_sample_db(client):
 
 
34
  # Sample programs for testing
35
  samples = [
36
  """
 
50
  for code in samples:
51
  parts, sequence = parse_python_code(code)
52
  vectors = [part['vector'] for part in parts]
53
+ store_program(client, code, sequence, vectors)
54
+
55
+ def query_programs(client, operations, collection_name="python_programs", top_k=5):
56
+ """Query the database for programs matching the operations sequence."""
57
+ collection = create_collection(client, collection_name)
58
+
59
+ # Convert operations to a query vector (average of operation vectors)
60
+ query_vector = sum([create_vector(op, 0, (1, 1), 100, []) for op in operations], []) / len(operations) if operations else [0, 0, 0, 0, 0, 0]
61
+
62
+ # Perform similarity search
63
+ results = collection.query(
64
+ query_embeddings=[query_vector],
65
+ n_results=top_k,
66
+ include=["documents", "metadatas"]
67
+ )
68
+
69
+ # Process results
70
+ matching_programs = []
71
+ for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
72
+ sequence = meta['sequence'].split(',')
73
+ if is_subsequence(operations, sequence):
74
+ similarity = cosine_similarity([query_vector], [np.mean(eval(doc['vectors']), axis=0) if doc['vectors'] else [0, 0, 0, 0, 0, 0]])[0][0]
75
+ matching_programs.append({'id': meta['id'], 'code': doc, 'similarity': similarity})
76
+
77
+ return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)
78
+
79
+ from sklearn.metrics.pairwise import cosine_similarity
80
+ import numpy as np
81
+
82
+ def create_vector(category, level, location, total_lines, parent_path):
83
+ """Helper to create a vector for query (matches parser's create_vector)."""
84
+ category_map = {
85
+ 'import': 1, 'function': 2, 'async_function': 3, 'class': 4,
86
+ 'if': 5, 'while': 6, 'for': 7, 'try': 8, 'expression': 9, 'spacer': 10,
87
+ 'other': 11, 'elif': 12, 'else': 13, 'except': 14, 'finally': 15, 'return': 16,
88
+ 'assigned_variable': 17, 'input_variable': 18, 'returned_variable': 19
89
+ }
90
+ category_id = category_map.get(category, 0)
91
+ start_line, end_line = location
92
+ span = (end_line - start_line + 1) / total_lines
93
+ center_pos = ((start_line + end_line) / 2) / total_lines
94
+ parent_depth = len(parent_path)
95
+ parent_weight = sum(category_map.get(parent.split('[')[0].lower(), 0) * (1 / (i + 1))
96
+ for i, parent in enumerate(parent_path)) / max(1, len(category_map))
97
+ return [category_id, level, center_pos, span, parent_depth, parent_weight]
98
+
99
+ def is_subsequence(subseq, seq):
100
+ """Check if subseq is a subsequence of seq."""
101
+ it = iter(seq)
102
+ return all(item in it for item in subseq)
103
 
104
  if __name__ == '__main__':
105
+ client = init_chromadb()
106
+ populate_sample_db(client)