|
from fastapi import APIRouter, Depends, HTTPException, Request |
|
from pydantic import BaseModel |
|
import logging |
|
from typing import Optional |
|
|
|
from open_webui.apps.webui.models.memories import Memories, MemoryModel |
|
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT |
|
from open_webui.utils.utils import get_verified_user |
|
from open_webui.env import SRC_LOG_LEVELS |
|
|
|
|
|
# FastAPI router that the memory endpoints below register on.
router = APIRouter()

# Module-level logger; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
|
|
@router.get("/ef")
async def get_embeddings(request: Request):
    """Embed the fixed probe string "hello world" and return the vector.

    Appears to be a diagnostic route for checking that the application's
    configured ``EMBEDDING_FUNCTION`` works.

    NOTE(review): unlike every other route in this module, this one declares
    no ``get_verified_user`` dependency, so any caller can trigger an
    embedding computation — confirm that is intended.
    """
    embed = request.app.state.EMBEDDING_FUNCTION
    probe_vector = embed("hello world")
    return {"result": probe_vector}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/", response_model=list[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
    """Return every stored memory that belongs to the authenticated user."""
    user_memories = Memories.get_memories_by_user_id(user.id)
    return user_memories
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AddMemoryForm(BaseModel):
    # Request body for ``POST /add``.
    # Text of the new memory; the add route stores it and embeds it into the
    # user's vector collection.
    content: str
|
|
|
|
|
class MemoryUpdateModel(BaseModel):
    # Request body for ``POST /{memory_id}/update``.
    # Replacement memory text. ``None`` (the default) means no new content was
    # sent; the update route only re-embeds the memory when this is not None.
    content: Optional[str] = None
|
|
|
|
|
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
    request: Request,
    form_data: AddMemoryForm,
    user=Depends(get_verified_user),
):
    """Store a new memory for the user and index it in their vector collection.

    The memory row is created first. Its content is then embedded with the
    app's ``EMBEDDING_FUNCTION`` and upserted, keyed by the memory id, into
    the ``user-memory-<user id>`` collection.

    Returns:
        The newly created memory.

    Raises:
        HTTPException: 500 if ``Memories.insert_new_memory`` returned None.
    """
    memory = Memories.insert_new_memory(user.id, form_data.content)
    if memory is None:
        # The response model is Optional, so insertion can presumably fail with
        # None; report that clearly instead of crashing on ``memory.id``.
        raise HTTPException(status_code=500, detail="Failed to add memory")

    VECTOR_DB_CLIENT.upsert(
        collection_name=f"user-memory-{user.id}",
        items=[
            {
                "id": memory.id,
                "text": memory.content,
                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
                # Same metadata fields the reset and update routes write, so
                # every entry in the collection has a uniform shape.
                "metadata": {
                    "created_at": memory.created_at,
                    "updated_at": memory.updated_at,
                },
            }
        ],
    )

    return memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QueryMemoryForm(BaseModel):
    # Request body for ``POST /query``.
    # Query text; it is embedded and used to search the user's memory
    # collection.
    content: str
    # Number of nearest memories to return (used as the search ``limit``).
    # Declared Optional, so a client may explicitly send null.
    k: Optional[int] = 1
|
|
|
|
|
@router.post("/query")
async def query_memory(
    request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
    """Semantic search over the authenticated user's memories.

    Embeds ``form_data.content`` with the app's ``EMBEDDING_FUNCTION`` and
    returns whatever ``VECTOR_DB_CLIENT.search`` yields for the nearest
    ``form_data.k`` entries in the ``user-memory-<user id>`` collection.
    """
    # ``k`` is typed Optional[int], so ``"k": null`` passes validation; fall back
    # to the field's declared default instead of handing ``limit=None`` to the
    # vector client.
    limit = form_data.k if form_data.k is not None else 1

    results = VECTOR_DB_CLIENT.search(
        collection_name=f"user-memory-{user.id}",
        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
        limit=limit,
    )

    return results
|
|
|
|
|
|
|
|
|
|
|
@router.post("/reset", response_model=bool)
async def reset_memory_from_vector_db(
    request: Request, user=Depends(get_verified_user)
):
    """Rebuild the user's vector collection from their stored memories.

    Drops the ``user-memory-<user id>`` collection, then re-embeds and
    upserts every memory returned by ``Memories.get_memories_by_user_id``.

    Returns:
        True once the rebuild has been issued.
    """
    collection_name = f"user-memory-{user.id}"

    # Deleting the collection can raise (the delete-all route guards this same
    # call), presumably when it does not exist yet. A reset should still
    # rebuild in that case, so treat the failure as non-fatal.
    try:
        VECTOR_DB_CLIENT.delete_collection(collection_name)
    except Exception as e:
        log.warning("Could not delete vector collection %s: %s", collection_name, e)

    memories = Memories.get_memories_by_user_id(user.id)
    if not memories:
        # Nothing to index. Skip the upsert rather than sending an empty batch
        # (NOTE(review): some vector backends reject empty upserts — confirm).
        return True

    VECTOR_DB_CLIENT.upsert(
        collection_name=collection_name,
        items=[
            {
                "id": memory.id,
                "text": memory.content,
                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
                "metadata": {
                    "created_at": memory.created_at,
                    "updated_at": memory.updated_at,
                },
            }
            for memory in memories
        ],
    )

    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/delete/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
    """Delete all of the user's memories and their vector collection.

    Returns:
        True if ``Memories.delete_memories_by_user_id`` reported success,
        False otherwise. Removing the ``user-memory-<user id>`` vector
        collection is best-effort: a failure there is logged and does not
        change the result.
    """
    if not Memories.delete_memories_by_user_id(user.id):
        return False

    try:
        VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
    except Exception:
        # Best-effort cleanup; log with traceback and context so the failure
        # is diagnosable (a bare ``log.error(e)`` loses both).
        log.exception("Failed to delete vector collection user-memory-%s", user.id)

    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
async def update_memory_by_id(
    memory_id: str,
    request: Request,
    form_data: MemoryUpdateModel,
    user=Depends(get_verified_user),
):
    """Update one of the caller's memories and re-index it.

    Only memories owned by the authenticated user can be updated. Any other
    id yields the same 404 as an id that does not exist. When new content is
    sent, the memory is re-embedded and upserted into the
    ``user-memory-<user id>`` collection.

    Returns:
        The updated memory, or the unchanged memory when ``content`` is None.

    Raises:
        HTTPException: 404 if the memory does not exist or is not the caller's.
    """
    # ``Memories.update_memory_by_id`` is not scoped to a user, so ownership
    # must be checked here. Otherwise any verified user could overwrite
    # another user's memory by id and index it into their own collection.
    # The single-delete route scopes by user id in the same spirit.
    owned = {m.id: m for m in Memories.get_memories_by_user_id(user.id)}
    existing = owned.get(memory_id)
    if existing is None:
        raise HTTPException(status_code=404, detail="Memory not found")

    if form_data.content is None:
        # None means "no new content" (the vector store is never touched for
        # it), so skip the DB write as well. Passing None through would
        # presumably overwrite the stored text — NOTE(review): confirm.
        return existing

    memory = Memories.update_memory_by_id(memory_id, form_data.content)
    if memory is None:
        raise HTTPException(status_code=404, detail="Memory not found")

    VECTOR_DB_CLIENT.upsert(
        collection_name=f"user-memory-{user.id}",
        items=[
            {
                "id": memory.id,
                "text": memory.content,
                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
                "metadata": {
                    "created_at": memory.created_at,
                    "updated_at": memory.updated_at,
                },
            }
        ],
    )

    return memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
    """Delete one of the caller's memories and its vector entry.

    Returns:
        True if ``Memories.delete_memory_by_id_and_user_id`` deleted the row
        for this user, False otherwise. Removing the vector entry is
        best-effort: the row is already gone by then, so a failure is logged
        rather than turned into an error response. ``POST /reset`` can rebuild
        the collection from the stored memories.
    """
    if not Memories.delete_memory_by_id_and_user_id(memory_id, user.id):
        return False

    try:
        VECTOR_DB_CLIENT.delete(
            collection_name=f"user-memory-{user.id}", ids=[memory_id]
        )
    except Exception:
        log.exception(
            "Failed to delete vector entry %s from user-memory-%s",
            memory_id,
            user.id,
        )

    return True
|
|