File size: 4,044 Bytes
2017cb6
dda378f
2017cb6
dda378f
 
 
 
 
 
 
 
 
 
 
 
 
 
2017cb6
dda378f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2017cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dda378f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2017cb6
 
dda378f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# database.py
import hashlib
import os
import textwrap

import chromadb

from parser import parse_python_code

def init_chromadb():
    """Create and return a ChromaDB client built with default settings.

    The original intent is an in-memory client for now; switching to an
    on-disk (persistent) client is a possible later change.
    """
    return chromadb.Client()

def create_collection(client, collection_name="python_programs"):
    """Return the collection named ``collection_name``, creating it if missing.

    Args:
        client: ChromaDB client (e.g. from ``init_chromadb()``).
        collection_name: Name of the collection holding Python programs.

    Returns:
        The existing or newly created collection object.
    """
    # get_or_create_collection does the lookup-or-create in one call.
    # The previous get/bare-except/create pattern swallowed every error
    # (including KeyboardInterrupt and connection failures) and then
    # retried as a create.
    return client.get_or_create_collection(name=collection_name)

def _program_embedding(vectors):
    """Collapse per-part feature vectors into the single flat embedding Chroma stores."""
    arr = np.asarray(vectors, dtype=float)
    if arr.size == 0:
        # Same 6-element zero vector the query side uses for empty input.
        return [0.0] * 6
    if arr.ndim == 2:
        # One row per program part -> element-wise mean, mirroring how the
        # query vector is built from several operations.
        arr = arr.mean(axis=0)
    return arr.tolist()


def store_program(client, code, sequence, vectors, collection_name="python_programs"):
    """Store one program in the collection and return its ID.

    Args:
        client: ChromaDB client.
        code: Program source text; stored as the document.
        sequence: Iterable of operation-name strings; stored comma-joined
            in the metadata under ``"sequence"``.
        vectors: Per-part feature vectors (a list of equal-length numeric
            lists, as built by ``populate_sample_db``) or a single flat vector.
        collection_name: Target collection.

    Returns:
        The program ID: the SHA-256 hex digest of ``code``.
    """
    collection = create_collection(client, collection_name)

    # hash() on str is salted per interpreter run (PYTHONHASHSEED), so the
    # same program would get a different ID every run. A content digest is
    # stable.
    program_id = hashlib.sha256(code.encode("utf-8")).hexdigest()

    collection.add(
        documents=[code],
        metadatas=[{"sequence": ",".join(sequence)}],
        ids=[program_id],
        # Each embedding must be one flat list of numbers, not a list of vectors.
        embeddings=[_program_embedding(vectors)],
    )
    return program_id

def populate_sample_db(client):
    """Parse a couple of small sample programs and store them via ``store_program``.

    Args:
        client: ChromaDB client to write into.
    """
    # The literals are indented to line up with this function. Without
    # dedenting, their first statement is indented, which is invalid
    # top-level Python (compile()/ast.parse raise IndentationError).
    # NOTE(review): presumes parse_python_code parses the text as Python —
    # confirm in parser.py.
    samples = [
        textwrap.dedent(
            """
            import os
            def add_one(x):
                y = x + 1
                return y
            """
        ),
        textwrap.dedent(
            """
            def multiply(a, b):
                c = a * b
                if c > 0:
                    return c
            """
        ),
    ]

    for code in samples:
        parts, sequence = parse_python_code(code)
        vectors = [part['vector'] for part in parts]
        store_program(client, code, sequence, vectors)

def _cosine(a, b):
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom else 0.0


def query_programs(client, operations, collection_name="python_programs", top_k=5):
    """Query the database for programs matching the operations sequence.

    The query embedding is the element-wise mean of ``create_vector`` over
    ``operations``. Chroma returns the ``top_k`` nearest stored programs.
    Those whose stored operation sequence contains ``operations`` as an
    ordered subsequence are kept, ranked by cosine similarity between the
    query and stored embeddings.

    Args:
        client: ChromaDB client.
        operations: Sequence of operation/category names.
        collection_name: Collection to search.
        top_k: Maximum number of nearest neighbours fetched before filtering.

    Returns:
        List of ``{"id", "code", "similarity"}`` dicts, highest similarity
        first. Empty when the collection is empty or ``top_k < 1``.
    """
    collection = create_collection(client, collection_name)

    # n_results larger than the collection is rejected or warned about by
    # Chroma, and an empty collection has nothing to return.
    n_results = min(top_k, collection.count())
    if n_results < 1:
        return []

    if operations:
        # Element-wise mean. sum(list_of_lists, []) would concatenate the
        # lists, and dividing a list by an int raises TypeError.
        query_vector = np.mean(
            [create_vector(op, 0, (1, 1), 100, []) for op in operations], axis=0
        ).tolist()
    else:
        query_vector = [0.0] * 6

    results = collection.query(
        query_embeddings=[query_vector],
        n_results=n_results,
        include=["documents", "metadatas", "embeddings"],
    )

    matching_programs = []
    # IDs come back from query() directly; they are not stored in metadata.
    # Stored embeddings are compared directly instead of eval()-ing
    # document text.
    for program_id, doc, meta, embedding in zip(
        results['ids'][0],
        results['documents'][0],
        results['metadatas'][0],
        results['embeddings'][0],
    ):
        sequence = meta['sequence'].split(',')
        if is_subsequence(operations, sequence):
            matching_programs.append({
                'id': program_id,
                'code': doc,
                'similarity': _cosine(query_vector, embedding),
            })

    return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)

from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

_CATEGORY_IDS = {
    'import': 1, 'function': 2, 'async_function': 3, 'class': 4,
    'if': 5, 'while': 6, 'for': 7, 'try': 8, 'expression': 9, 'spacer': 10,
    'other': 11, 'elif': 12, 'else': 13, 'except': 14, 'finally': 15, 'return': 16,
    'assigned_variable': 17, 'input_variable': 18, 'returned_variable': 19,
}


def create_vector(category, level, location, total_lines, parent_path):
    """Build a 6-element feature vector for a query operation.

    Meant to mirror the parser's ``create_vector`` so that query vectors
    and stored program vectors are comparable.

    Returns:
        ``[category_id, level, center_pos, span, parent_depth, parent_weight]``:

        * ``category_id``: numeric code of ``category`` (0 if unknown).
        * ``center_pos``: midpoint of the ``(start, end)`` line range, as a
          fraction of ``total_lines``.
        * ``span``: length of that range, as a fraction of ``total_lines``.
        * ``parent_depth``: ``len(parent_path)``.
        * ``parent_weight``: sum of parent category codes weighted by
          ``1/(i+1)``, divided by the number of known categories.
    """
    category_id = _CATEGORY_IDS.get(category, 0)
    first, last = location
    span = (last - first + 1) / total_lines
    midpoint = (first + last) / 2
    center_pos = midpoint / total_lines
    depth = len(parent_path)
    # Parent entries may look like "Function[name]"; only the kind before
    # "[" is used.
    discounted = sum(
        _CATEGORY_IDS.get(ancestor.split('[')[0].lower(), 0) * (1 / (rank + 1))
        for rank, ancestor in enumerate(parent_path)
    )
    return [
        category_id,
        level,
        center_pos,
        span,
        depth,
        discounted / max(1, len(_CATEGORY_IDS)),
    ]

def is_subsequence(subseq, seq):
    """Return True if every item of ``subseq`` appears in ``seq`` in the same order.

    Items need not be contiguous. Matching is identity-or-equality, and a
    single pass over ``seq`` is shared across all items.
    """
    cursor = iter(seq)
    for wanted in subseq:
        for candidate in cursor:
            if candidate is wanted or candidate == wanted:
                break
        else:
            return False
    return True

if __name__ == '__main__':
    # Create a client and load the sample programs into it.
    populate_sample_db(init_chromadb())