Spaces:
Running
Running
Update database.py
Browse files- database.py +21 -13
database.py
CHANGED
@@ -24,7 +24,7 @@ PERSIST_DIR = "./chroma_data" # Directory for persistent storage (optional)
|
|
24 |
USE_GPU = False # Default to CPU, set to True for GPU if available
|
25 |
|
26 |
def init_chromadb(persist_dir=PERSIST_DIR):
|
27 |
-
"""Initialize ChromaDB client, optionally with persistent storage, with error handling."""
|
28 |
try:
|
29 |
# Use persistent storage if directory exists, otherwise in-memory
|
30 |
if os.path.exists(persist_dir):
|
@@ -39,17 +39,19 @@ def init_chromadb(persist_dir=PERSIST_DIR):
|
|
39 |
raise
|
40 |
|
41 |
def create_collection(client, collection_name=DB_NAME):
|
42 |
-
"""Create or get a ChromaDB collection for Python programs, with error handling."""
|
43 |
try:
|
44 |
collection = client.get_or_create_collection(name=collection_name)
|
45 |
-
logger.info(f"Using ChromaDB collection: {collection_name}")
|
|
|
|
|
46 |
return collection
|
47 |
except Exception as e:
|
48 |
logger.error(f"Error creating or getting collection {collection_name}: {e}")
|
49 |
raise
|
50 |
|
51 |
def store_program(client, code, sequence, vectors, collection_name=DB_NAME):
|
52 |
-
"""Store a program in ChromaDB with its code, sequence, and vectors, with error handling."""
|
53 |
try:
|
54 |
collection = create_collection(client, collection_name)
|
55 |
|
@@ -65,7 +67,7 @@ def store_program(client, code, sequence, vectors, collection_name=DB_NAME):
|
|
65 |
ids=[program_id],
|
66 |
embeddings=[flattened_vectors] # Pass as 6D vector
|
67 |
)
|
68 |
-
logger.info(f"Stored program in ChromaDB: {program_id}")
|
69 |
return program_id
|
70 |
except Exception as e:
|
71 |
logger.error(f"Error storing program in ChromaDB: {e}")
|
@@ -93,13 +95,14 @@ def populate_sample_db(client):
|
|
93 |
parts, sequence = parse_python_code(code)
|
94 |
vectors = [part['vector'] for part in parts]
|
95 |
store_program(client, code, sequence, vectors)
|
96 |
-
|
|
|
97 |
except Exception as e:
|
98 |
logger.error(f"Error populating sample database: {e}")
|
99 |
raise
|
100 |
|
101 |
def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semantic_query=None):
|
102 |
-
"""Query ChromaDB for programs matching the operations sequence or semantic description, with error handling."""
|
103 |
try:
|
104 |
collection = create_collection(client, collection_name)
|
105 |
|
@@ -139,7 +142,7 @@ def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semanti
|
|
139 |
similarity = cosine_similarity([query_vector], [semantic_vector])[0][0] if semantic_vector and query_vector else 0
|
140 |
matching_programs.append({'id': meta['id'], 'code': doc, 'similarity': similarity, 'description': meta.get('description_tokens', ''), 'program_vectors': meta.get('program_vectors', '[]')})
|
141 |
|
142 |
-
logger.info(f"Queried {len(matching_programs)} programs from ChromaDB")
|
143 |
return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)
|
144 |
except Exception as e:
|
145 |
logger.error(f"Error querying programs from ChromaDB: {e}")
|
@@ -238,7 +241,7 @@ def generate_semantic_vector(description, total_lines=100, use_gpu=False):
|
|
238 |
return vector
|
239 |
|
240 |
def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
|
241 |
-
"""Save ChromaDB data to Hugging Face Dataset, with error handling."""
|
242 |
try:
|
243 |
client = init_chromadb()
|
244 |
collection = client.get_collection(DB_NAME)
|
@@ -255,16 +258,19 @@ def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY"))
|
|
255 |
|
256 |
# Create a Hugging Face Dataset
|
257 |
dataset = Dataset.from_dict(data)
|
|
|
258 |
|
259 |
# Push to Hugging Face Hub
|
260 |
dataset.push_to_hub(dataset_name, token=token)
|
261 |
logger.info(f"Dataset pushed to Hugging Face Hub as {dataset_name}")
|
|
|
|
|
262 |
except Exception as e:
|
263 |
logger.error(f"Error pushing dataset to Hugging Face Hub: {e}")
|
264 |
raise
|
265 |
|
266 |
def load_chromadb_from_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
|
267 |
-
"""Load ChromaDB data from Hugging Face Dataset, handle empty dataset, with error handling."""
|
268 |
try:
|
269 |
dataset = load_dataset(dataset_name, split="train", token=token)
|
270 |
client = init_chromadb()
|
@@ -272,15 +278,17 @@ def load_chromadb_from_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY"
|
|
272 |
|
273 |
for item in dataset:
|
274 |
store_program(client, item["code"], item["sequence"].split(','), item["program_vectors"])
|
275 |
-
|
|
|
276 |
return client
|
277 |
except Exception as e:
|
278 |
logger.error(f"Error loading dataset from Hugging Face: {e}")
|
279 |
# Fallback: Create empty collection
|
280 |
client = init_chromadb()
|
281 |
-
create_collection(client)
|
|
|
282 |
return client
|
283 |
|
284 |
if __name__ == '__main__':
|
285 |
client = load_chromadb_from_hf()
|
286 |
-
logger.info("Database initialized or loaded from Hugging Face Hub")
|
|
|
24 |
USE_GPU = False # Default to CPU, set to True for GPU if available
|
25 |
|
26 |
def init_chromadb(persist_dir=PERSIST_DIR):
|
27 |
+
"""Initialize ChromaDB client, optionally with persistent storage, with error handling and logging."""
|
28 |
try:
|
29 |
# Use persistent storage if directory exists, otherwise in-memory
|
30 |
if os.path.exists(persist_dir):
|
|
|
39 |
raise
|
40 |
|
41 |
def create_collection(client, collection_name=DB_NAME):
    """Create or get a ChromaDB collection for Python programs, with error handling and logging.

    Args:
        client: An initialized ChromaDB client.
        collection_name: Name of the collection to create or fetch (defaults to DB_NAME).

    Returns:
        The ChromaDB collection object.

    Raises:
        ValueError: If the client handed back an unusable collection object.
        Exception: Any ChromaDB error is logged and re-raised for the caller.
    """
    try:
        collection = client.get_or_create_collection(name=collection_name)
        # Validate BEFORE first use: the original called collection.count()
        # inside the log line first, which would raise AttributeError on a
        # None/bad collection before this guard could ever fire.
        if collection is None or not hasattr(collection, 'add'):
            raise ValueError("ChromaDB collection creation or access failed")
        logger.info(f"Using ChromaDB collection: {collection_name}, contains {collection.count()} entries")
        return collection
    except Exception as e:
        logger.error(f"Error creating or getting collection {collection_name}: {e}")
        raise
|
52 |
|
53 |
def store_program(client, code, sequence, vectors, collection_name=DB_NAME):
|
54 |
+
"""Store a program in ChromaDB with its code, sequence, and vectors, with error handling and logging."""
|
55 |
try:
|
56 |
collection = create_collection(client, collection_name)
|
57 |
|
|
|
67 |
ids=[program_id],
|
68 |
embeddings=[flattened_vectors] # Pass as 6D vector
|
69 |
)
|
70 |
+
logger.info(f"Stored program in ChromaDB: {program_id}, total entries: {collection.count()}")
|
71 |
return program_id
|
72 |
except Exception as e:
|
73 |
logger.error(f"Error storing program in ChromaDB: {e}")
|
|
|
95 |
parts, sequence = parse_python_code(code)
|
96 |
vectors = [part['vector'] for part in parts]
|
97 |
store_program(client, code, sequence, vectors)
|
98 |
+
collection = create_collection(client, DB_NAME)
|
99 |
+
logger.info(f"Populated ChromaDB with sample programs, total entries: {collection.count()}")
|
100 |
except Exception as e:
|
101 |
logger.error(f"Error populating sample database: {e}")
|
102 |
raise
|
103 |
|
104 |
def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semantic_query=None):
|
105 |
+
"""Query ChromaDB for programs matching the operations sequence or semantic description, with error handling and logging."""
|
106 |
try:
|
107 |
collection = create_collection(client, collection_name)
|
108 |
|
|
|
142 |
similarity = cosine_similarity([query_vector], [semantic_vector])[0][0] if semantic_vector and query_vector else 0
|
143 |
matching_programs.append({'id': meta['id'], 'code': doc, 'similarity': similarity, 'description': meta.get('description_tokens', ''), 'program_vectors': meta.get('program_vectors', '[]')})
|
144 |
|
145 |
+
logger.info(f"Queried {len(matching_programs)} programs from ChromaDB, total entries: {collection.count()}")
|
146 |
return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)
|
147 |
except Exception as e:
|
148 |
logger.error(f"Error querying programs from ChromaDB: {e}")
|
|
|
241 |
return vector
|
242 |
|
243 |
def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
|
244 |
+
"""Save ChromaDB data to Hugging Face Dataset, with error handling and logging."""
|
245 |
try:
|
246 |
client = init_chromadb()
|
247 |
collection = client.get_collection(DB_NAME)
|
|
|
258 |
|
259 |
# Create a Hugging Face Dataset
|
260 |
dataset = Dataset.from_dict(data)
|
261 |
+
logger.info(f"Created Hugging Face Dataset with {len(data['code'])} entries")
|
262 |
|
263 |
# Push to Hugging Face Hub
|
264 |
dataset.push_to_hub(dataset_name, token=token)
|
265 |
logger.info(f"Dataset pushed to Hugging Face Hub as {dataset_name}")
|
266 |
+
# Verify push (optional, could check dataset on Hub)
|
267 |
+
logger.info(f"Verified Hugging Face dataset push with {len(dataset)} entries")
|
268 |
except Exception as e:
|
269 |
logger.error(f"Error pushing dataset to Hugging Face Hub: {e}")
|
270 |
raise
|
271 |
|
272 |
def load_chromadb_from_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
|
273 |
+
"""Load ChromaDB data from Hugging Face Dataset, handle empty dataset, with error handling and logging."""
|
274 |
try:
|
275 |
dataset = load_dataset(dataset_name, split="train", token=token)
|
276 |
client = init_chromadb()
|
|
|
278 |
|
279 |
for item in dataset:
|
280 |
store_program(client, item["code"], item["sequence"].split(','), item["program_vectors"])
|
281 |
+
collection = create_collection(client, DB_NAME)
|
282 |
+
logger.info(f"Loaded {len(dataset)} entries from Hugging Face Hub into ChromaDB, total entries: {collection.count()}")
|
283 |
return client
|
284 |
except Exception as e:
|
285 |
logger.error(f"Error loading dataset from Hugging Face: {e}")
|
286 |
# Fallback: Create empty collection
|
287 |
client = init_chromadb()
|
288 |
+
collection = create_collection(client)
|
289 |
+
logger.info(f"Created empty ChromaDB collection: {DB_NAME}, contains {collection.count()} entries")
|
290 |
return client
|
291 |
|
292 |
if __name__ == '__main__':
    # Script entry point: load the database from the Hugging Face Hub
    # (load_chromadb_from_hf falls back to an empty collection on failure).
    client = load_chromadb_from_hf()
    # BUG FIX: the original string lacked the f-prefix, so the literal text
    # "{client.get_collection(DB_NAME).count()}" was logged instead of the
    # actual entry count.
    logger.info(f"Database initialized or loaded from Hugging Face Hub, contains {client.get_collection(DB_NAME).count()} entries")
|