# Spaces:
# No application file
# No application file
import uuid

import chromadb
from flask import Flask, request, jsonify
# Flask application object that the request handlers in this module belong to.
app = Flask(__name__)

# Chroma DB client constructed with the library's default settings.
# NOTE(review): presumably the default client keeps data in memory only, so
# collections would not survive a process restart — confirm against the
# installed chromadb version.
client = chromadb.Client()
def get_or_create_collection(collection_name: str):
    """Return the Chroma collection called *collection_name*, creating it if absent.

    Args:
        collection_name: Name of the collection to fetch or create.

    Returns:
        The collection object handed back by the module-level ``client``.
    """
    # Delegate to the client's own get-or-create call instead of testing
    # ``collection_name in client.list_collections()``: that listing is not
    # reliably a list of plain name strings, so the membership test could miss
    # an existing collection and the follow-up ``create_collection`` would fail.
    return client.get_or_create_collection(collection_name)
# NOTE(review): SOURCE showed no route registration for this view; the path
# below simply mirrors the function name — confirm against existing clients.
@app.route("/create_collection", methods=["POST"])
def create_collection():
    """Ensure the collection named in the JSON request body exists.

    Reads ``collection_name`` from the JSON body.

    Returns:
        A ``(response, status)`` pair: 400 with an ``error`` message when the
        name is missing (or the body is not a JSON object), otherwise 200 with
        a ``status`` message.
    """
    # ``silent=True`` yields None for a missing or malformed JSON body instead
    # of raising; any non-object payload is treated as "no fields supplied".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    collection_name = data.get("collection_name")
    if not collection_name:
        return jsonify({"error": "Collection name is required"}), 400

    # Called only for its side effect; the collection object is not needed here.
    get_or_create_collection(collection_name)
    return jsonify({"status": f"Collection '{collection_name}' is ready."}), 200
# NOTE(review): SOURCE showed no route registration for this view; the path
# below simply mirrors the function name — confirm against existing clients.
@app.route("/add_embedding", methods=["POST"])
def add_embedding():
    """Store one document and its embedding in the requested collection.

    Reads ``collection_name``, ``document`` and ``embedding`` from the JSON
    body. An optional ``id`` may be supplied; otherwise a random one is
    generated.

    Returns:
        A ``(response, status)`` pair: 400 with an ``error`` message when a
        required field is missing (or the body is not a JSON object),
        otherwise 200 with ``status`` and the ``id`` the record was stored
        under.
    """
    # ``silent=True`` yields None for a missing or malformed JSON body instead
    # of raising; any non-object payload is treated as "no fields supplied".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    collection_name = data.get("collection_name")
    document = data.get("document")
    embedding = data.get("embedding")
    if not collection_name or not document or not embedding:
        return jsonify({"error": "Collection name, document, and embedding are required"}), 400

    # Get or create the collection
    collection = get_or_create_collection(collection_name)

    # Every record needs a string id; fall back to a random one so repeated
    # documents do not collide.
    doc_id = str(data.get("id") or uuid.uuid4().hex)

    # Keyword arguments are used deliberately: the original positional call
    # (document, metadata, embedding) did not line up with the parameter order
    # of ``Collection.add``, so values landed in the wrong fields.
    collection.add(
        ids=[doc_id],
        embeddings=[embedding],
        metadatas=[{"source": document}],
        documents=[document],
    )
    return jsonify({"status": "success", "id": doc_id}), 200
# NOTE(review): SOURCE showed no route registration for this view; the path
# below simply mirrors the function name — confirm against existing clients.
@app.route("/query_embedding", methods=["POST"])
def query_embedding():
    """Return the documents nearest to a query embedding.

    Reads ``collection_name`` and ``embedding`` from the JSON body, plus an
    optional ``n_results`` (default 5).

    Returns:
        A ``(response, status)`` pair: 400 with an ``error`` message when a
        required field is missing or ``n_results`` is not a positive integer,
        otherwise 200 with the ``documents`` entry of the query result under
        ``results``.
    """
    # ``silent=True`` yields None for a missing or malformed JSON body instead
    # of raising; any non-object payload is treated as "no fields supplied".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    collection_name = data.get("collection_name")
    # Named ``query_vector`` so the local does not shadow this function's name.
    query_vector = data.get("embedding")
    n_results = data.get("n_results", 5)
    if not collection_name or not query_vector:
        return jsonify({"error": "Collection name and embedding are required"}), 400

    # Reject strings, floats, booleans and non-positive counts up front rather
    # than letting the query call fail with a server error.
    if isinstance(n_results, bool) or not isinstance(n_results, int) or n_results < 1:
        return jsonify({"error": "n_results must be a positive integer"}), 400

    # Get or create the collection (an unknown name is created empty, as before).
    collection = get_or_create_collection(collection_name)

    # Query the collection for nearest neighbors
    results = collection.query(query_embeddings=[query_vector], n_results=n_results)
    return jsonify({"results": results["documents"]}), 200
if __name__ == "__main__": | |
app.run(host="0.0.0.0", port=7860 | |