chromadb / app.py
srivatsavdamaraju's picture
Update app.py
6b30c1d verified
raw
history blame
2.24 kB
import uuid

from flask import Flask, request, jsonify
import chromadb
# The Flask application object; every route handler in this module is
# registered on it via @app.route.
app = Flask(import_name=__name__)

# A single Chroma client shared by all request handlers.
# NOTE(review): chromadb.Client() with default settings is, in recent chromadb
# releases, an in-memory (ephemeral) client, so stored collections would not
# survive a restart -- confirm whether persistence is wanted.
client = chromadb.Client()
# Function to get collection by name or create if it doesn't exist
def get_or_create_collection(collection_name: str):
if collection_name in client.list_collections():
return client.get_collection(collection_name)
else:
return client.create_collection(collection_name)
@app.route("/create_collection/", methods=["POST"])
def create_collection():
    """Ensure a Chroma collection exists, creating it if necessary.

    Expects a JSON object body with a ``collection_name`` string.

    Returns:
        200 with a status message once the collection exists, or 400 with an
        ``error`` message when the body is not a JSON object or
        ``collection_name`` is missing or not a string.
    """
    # silent=True yields None (instead of aborting) for a missing or
    # unparsable JSON body. Together with the dict check, this keeps malformed
    # requests at 400 rather than crashing in data.get() with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    collection_name = data.get("collection_name")
    if not collection_name:
        return jsonify({"error": "Collection name is required"}), 400
    if not isinstance(collection_name, str):
        return jsonify({"error": "Collection name must be a string"}), 400

    # Called only for its side effect; the collection object is not needed here.
    get_or_create_collection(collection_name)
    return jsonify({"status": f"Collection '{collection_name}' is ready."}), 200
@app.route("/add_embedding/", methods=["POST"])
def add_embedding():
    """Store one document together with its precomputed embedding.

    Expects a JSON object body with:
        collection_name (str): target collection; created if it does not exist.
        document (str): the document text.
        embedding (list of numbers): the document's embedding vector.
        id (str, optional): record ID; a random UUID4 string is used if omitted.

    Returns:
        200 with ``{"status": "success", "id": <record id>}`` on success, or
        400 with an ``error`` message when the body is not a JSON object or a
        field is missing or has the wrong type.
    """
    # silent=True yields None (instead of aborting) for a missing or
    # unparsable JSON body. Together with the dict check, this keeps malformed
    # requests at 400 rather than crashing in data.get() with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    collection_name = data.get("collection_name")
    document = data.get("document")
    embedding = data.get("embedding")
    if not collection_name or not document or not embedding:
        return jsonify({"error": "Collection name, document, and embedding are required"}), 400
    if not isinstance(collection_name, str) or not isinstance(document, str):
        return jsonify({"error": "Collection name and document must be strings"}), 400
    # bool is a subclass of int, so exclude it explicitly.
    if not isinstance(embedding, list) or not all(
        isinstance(value, (int, float)) and not isinstance(value, bool)
        for value in embedding
    ):
        return jsonify({"error": "Embedding must be a list of numbers"}), 400

    doc_id = data.get("id")
    if doc_id is None:
        doc_id = str(uuid.uuid4())
    elif not isinstance(doc_id, str) or not doc_id:
        return jsonify({"error": "id must be a non-empty string"}), 400

    # Get or create the collection
    collection = get_or_create_collection(collection_name)
    # Keyword arguments on purpose. Collection.add's positional order is
    # (ids, embeddings, metadatas, documents, ...), so a positional call
    # written as (documents, metadatas, embeddings) misassigns every field.
    collection.add(
        ids=[doc_id],
        embeddings=[embedding],
        metadatas=[{"source": document}],
        documents=[document],
    )
    return jsonify({"status": "success", "id": doc_id}), 200
@app.route("/query_embedding/", methods=["POST"])
def query_embedding():
    """Return the stored documents nearest to a query embedding.

    Expects a JSON object body with:
        collection_name (str): collection to search; created empty if it
            does not exist.
        embedding (list of numbers): the query embedding.
        n_results (int, optional): positive number of neighbours, default 5.

    Returns:
        200 with ``{"results": ...}`` holding the ``documents`` field of
        Chroma's query result (Chroma groups it per query embedding; one is
        sent here). Returns 400 with an ``error`` message on a malformed
        request.
    """
    # silent=True yields None (instead of aborting) for a missing or
    # unparsable JSON body. Together with the dict check, this keeps malformed
    # requests at 400 rather than crashing in data.get() with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    collection_name = data.get("collection_name")
    # Named `embedding` (not `query_embedding`) so it does not shadow this view.
    embedding = data.get("embedding")
    n_results = data.get("n_results", 5)
    if not collection_name or not embedding:
        return jsonify({"error": "Collection name and embedding are required"}), 400
    if not isinstance(collection_name, str):
        return jsonify({"error": "Collection name must be a string"}), 400
    # bool is a subclass of int, so exclude it explicitly in both checks.
    if not isinstance(embedding, list) or not all(
        isinstance(value, (int, float)) and not isinstance(value, bool)
        for value in embedding
    ):
        return jsonify({"error": "Embedding must be a list of numbers"}), 400
    if isinstance(n_results, bool) or not isinstance(n_results, int) or n_results < 1:
        return jsonify({"error": "n_results must be a positive integer"}), 400

    # Get or create the collection
    collection = get_or_create_collection(collection_name)
    # Query the collection for nearest neighbors
    results = collection.query(query_embeddings=[embedding], n_results=n_results)
    return jsonify({"results": results["documents"]}), 200
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860