Commit 640b1c8
Parent(s): initial commit
Files added:
- .vscode/launch.json
- .vscode/settings.json
- DockerComposeConfiguration
- Dockerfile
- Readme.md
- config/__init__.py
- config/config.py
- requirements.txt
- src/__init__.py
- src/agents/__init__.py
- src/agents/rag_agent.py
- src/embeddings/__init__.py
- src/embeddings/base_embedding.py
- src/embeddings/huggingface_embedding.py
- src/llms/__init__.py
- src/llms/base_llm.py
- src/llms/ollama_llm.py
- src/llms/openai_llm.py
- src/main.py
- src/utils/__init__.py
- src/utils/document_loader.py
- src/utils/logger.py
- src/utils/text_splitter.py
- src/vectorstores/__init__.py
- src/vectorstores/base_vectorstore.py
- src/vectorstores/chroma_vectorstore.py
.vscode/launch.json
ADDED
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: FastAPI",
            "type": "python",
            "request": "launch",
            "module": "uvicorn",
            "args": [
                "src.main:app",
                "--reload"
            ],
            "jinja": true,
            "justMyCode": true,
            "env": {
                "PYTHONPATH": "${workspaceFolder}"
            }
        },
        {
            "name": "Python: Test",
            "type": "python",
            "request": "launch",
            "module": "pytest",
            "args": [
                "tests"
            ],
            "console": "integratedTerminal"
        }
    ]
}
.vscode/settings.json
ADDED
{
    "python.pythonPath": "${workspaceFolder}/venv/bin/python",
    "python.linting.enabled": true,
    "python.linting.pylintEnabled": true,
    "python.formatting.provider": "black",
    "editor.formatOnSave": true,
    "python.testing.pytestArgs": [
        "tests"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
DockerComposeConfiguration
ADDED
version: '3.8'

services:
  app:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    volumes:
      - ./:/app
    depends_on:
      - ollama

  ollama:
    image: ollama/ollama
    ports:
      - "11434:11434"
    volumes:
      - ollama-data:/root/.ollama

  chroma:
    image: chromadb/chroma
    ports:
      - "8001:8000"  # host port 8001 avoids clashing with the app on 8000
    volumes:
      - chroma-data:/chroma
    environment:
      - PERSIST_DIRECTORY=/chroma

volumes:
  ollama-data:
  chroma-data:
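Because the Compose file is not named docker-compose.yml, it has to be passed explicitly with -f. A typical bring-up, including pulling the default llama2 model into the ollama service before the first request (a sketch, assuming Docker Compose v2):

```bash
docker compose -f DockerComposeConfiguration up --build -d
docker compose -f DockerComposeConfiguration exec ollama ollama pull llama2
```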
Dockerfile
ADDED
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set the working directory in the container
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy the current directory contents into the container at /app
COPY . /app

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Make port 8000 available to the world outside this container
EXPOSE 8000

# Define environment variable
ENV NAME=RAGChatbot

# Run the application
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
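For running the API container without Compose, a plain build-and-run pair also works; the image tag is illustrative:

```bash
docker build -t rag-chatbot .
docker run --rm -p 8000:8000 --env-file .env rag-chatbot
```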
Readme.md
ADDED
# RAG Chatbot Application

## Project Overview
A modular Retrieval Augmented Generation (RAG) chatbot application built with FastAPI, supporting multiple LLM providers and embedding models.

## Project Structure
- `config/`: Configuration management
- `src/`: Main application source code
- `tests/`: Unit and integration tests
- `data/`: Document storage and ingestion

## Prerequisites
- Python 3.9+
- pip
- (Optional) Virtual environment

## Installation

1. Clone the repository
```bash
git clone https://your-repo-url.git
cd rag-chatbot
```

2. Create a virtual environment
```bash
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`
```

3. Install dependencies
```bash
pip install -r requirements.txt
```

4. Set up environment variables
```bash
cp .env.example .env
# Edit .env with your credentials
```

## Configuration

### Environment Variables
- `OPENAI_API_KEY`: OpenAI API key
- `OLLAMA_BASE_URL`: Ollama server URL
- `EMBEDDING_MODEL`: Hugging Face embedding model
- `CHROMA_PATH`: Vector store persistence path
- `DEBUG`: Enable debug mode

## Running the Application

### Development Server
```bash
uvicorn src.main:app --reload
```

### Production Deployment
```bash
gunicorn -w 4 -k uvicorn.workers.UvicornWorker src.main:app
```

## Testing
```bash
pytest tests/
```

## Features
- Multiple LLM Provider Support
- Retrieval Augmented Generation
- Document Ingestion
- Flexible Configuration
- FastAPI Backend

## Contributing
1. Fork the repository
2. Create your feature branch
3. Commit your changes
4. Push to the branch
5. Create a Pull Request
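The Readme's Configuration section mirrors the variables read in config/config.py below; a sample .env with illustrative values:

```bash
OPENAI_API_KEY=sk-...
OPENAI_MODEL=gpt-3.5-turbo
OLLAMA_BASE_URL=http://localhost:11434
OLLAMA_MODEL=llama2
EMBEDDING_MODEL=all-MiniLM-L6-v2
CHROMA_PATH=./chroma_db
DEBUG=False
```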
config/__init__.py
ADDED
File without changes

config/config.py
ADDED
# config/config.py
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

class Settings:
    # OpenAI Configuration
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
    OPENAI_MODEL = os.getenv('OPENAI_MODEL', 'gpt-3.5-turbo')

    # Ollama Configuration
    OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
    OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'llama2')

    # Anthropic Configuration
    ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY', '')

    # Embedding Configuration
    EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')

    # Vector Store Configuration
    CHROMA_PATH = os.getenv('CHROMA_PATH', './chroma_db')

    # Application Configuration
    DEBUG = os.getenv('DEBUG', 'False') == 'True'

settings = Settings()
requirements.txt
ADDED
# Requirements for RAG Chatbot
fastapi==0.109.0
uvicorn==0.24.0
pydantic==2.6.1
python-dotenv==1.0.0

# LLM Providers
openai==1.12.0
anthropic==0.18.0
ollama==0.1.6

# Embedding and Vector Store
sentence-transformers==2.3.1
chromadb==0.4.22
huggingface_hub==0.20.3

# Optional: Additional dependencies
numpy==1.26.3
torch==2.1.2

PyPDF2==3.0.1
python-docx==1.0.1
requests==2.31.0
src/__init__.py
ADDED
File without changes

src/agents/__init__.py
ADDED
File without changes

src/agents/rag_agent.py
ADDED
# src/agents/rag_agent.py
from dataclasses import dataclass
from typing import List, Optional

from ..llms.base_llm import BaseLLM
from ..embeddings.base_embedding import BaseEmbedding
from ..vectorstores.base_vectorstore import BaseVectorStore
from ..utils.text_splitter import split_text

@dataclass
class RAGResponse:
    response: str
    context_docs: Optional[List[str]] = None

class RAGAgent:
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore
    ):
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3
    ) -> List[str]:
        """
        Retrieve relevant context documents for a given query

        Args:
            query (str): Input query to find context for
            top_k (int): Number of top context documents to retrieve

        Returns:
            List[str]: List of retrieved context documents
        """
        # Embed the query
        query_embedding = self.embedding.embed_query(query)

        # Retrieve similar documents
        context_docs = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        return context_docs

    def generate_response(
        self,
        query: str,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """
        Generate a response using RAG approach

        Args:
            query (str): User input query
            context_docs (Optional[List[str]]): Optional pre-provided context documents

        Returns:
            RAGResponse: Response with generated text and context
        """
        # If no context provided, retrieve from vector store
        if not context_docs:
            context_docs = self.retrieve_context(query)

        # Construct augmented prompt with context
        augmented_prompt = self._construct_prompt(query, context_docs)

        # Generate response using LLM
        response = self.llm.generate(augmented_prompt)

        return RAGResponse(
            response=response,
            context_docs=context_docs
        )

    def _construct_prompt(
        self,
        query: str,
        context_docs: List[str]
    ) -> str:
        """
        Construct a prompt with retrieved context

        Args:
            query (str): Original user query
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Augmented prompt for the LLM
        """
        context_str = "\n\n".join(context_docs)

        return f"""
Context Information:
{context_str}

User Query: {query}

Based on the context, please provide a comprehensive and accurate response.
"""
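A minimal sketch of how the agent composes with the concrete providers defined below, assuming a local Ollama server and the package importable as src; the two documents are illustrative:

```python
from src.agents.rag_agent import RAGAgent
from src.embeddings.huggingface_embedding import HuggingFaceEmbedding
from src.llms.ollama_llm import OllamaLanguageModel
from src.vectorstores.chroma_vectorstore import ChromaVectorStore

embedding = HuggingFaceEmbedding()
store = ChromaVectorStore(embedding_function=embedding.embed_documents)

# Ingest two illustrative documents, then ask a question grounded in them
docs = ["Chroma persists embeddings on disk.", "Ollama serves local LLMs over HTTP."]
store.add_documents(docs, embedding.embed_documents(docs))

agent = RAGAgent(llm=OllamaLanguageModel(), embedding=embedding, vector_store=store)
result = agent.generate_response("How are embeddings persisted?")
print(result.response)
print(result.context_docs)
```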
src/embeddings/__init__.py
ADDED
File without changes

src/embeddings/base_embedding.py
ADDED
# src/embeddings/base_embedding.py
from abc import ABC, abstractmethod
from typing import List

class BaseEmbedding(ABC):
    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of documents

        Args:
            texts (List[str]): List of texts to embed

        Returns:
            List[List[float]]: List of embeddings
        """
        pass

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query

        Args:
            text (str): Text to embed

        Returns:
            List[float]: Embedding vector
        """
        pass
src/embeddings/huggingface_embedding.py
ADDED
# src/embeddings/huggingface_embedding.py
from typing import List
from sentence_transformers import SentenceTransformer

from .base_embedding import BaseEmbedding

class HuggingFaceEmbedding(BaseEmbedding):
    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """
        Initialize HuggingFace embedding model

        Args:
            model_name (str): Name of the embedding model
        """
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of documents

        Args:
            texts (List[str]): List of texts to embed

        Returns:
            List[List[float]]: List of embeddings
        """
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query

        Args:
            text (str): Text to embed

        Returns:
            List[float]: Embedding vector
        """
        return self.model.encode(text).tolist()
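Both methods return plain lists of floats, so downstream similarity math needs no special types. A quick sanity check with numpy (pinned in requirements.txt); the model downloads on first use, and all-MiniLM-L6-v2 produces 384-dimensional vectors:

```python
import numpy as np
from src.embeddings.huggingface_embedding import HuggingFaceEmbedding

emb = HuggingFaceEmbedding()
vectors = emb.embed_documents(["hello world", "goodbye world"])
query = np.array(emb.embed_query("hello"))

print(len(vectors), len(vectors[0]))  # 2 384
for vec in map(np.array, vectors):
    # Cosine similarity between the query and each document vector
    print(float(query @ vec / (np.linalg.norm(query) * np.linalg.norm(vec))))
```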
src/llms/__init__.py
ADDED
File without changes

src/llms/base_llm.py
ADDED
# src/llms/base_llm.py
from abc import ABC, abstractmethod
from typing import List, Optional

class BaseLLM(ABC):
    @abstractmethod
    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response based on the given prompt

        Args:
            prompt (str): Input prompt for the model
            max_tokens (Optional[int]): Maximum number of tokens to generate
            temperature (float): Sampling temperature for randomness

        Returns:
            str: Generated response
        """
        pass

    @abstractmethod
    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize the input text

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: List of tokens
        """
        pass

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the input text

        Args:
            text (str): Input text to count tokens

        Returns:
            int: Number of tokens
        """
        pass
src/llms/ollama_llm.py
ADDED
# src/llms/ollama_llm.py
import requests
from typing import Optional, List

from .base_llm import BaseLLM

class OllamaLanguageModel(BaseLLM):
    def __init__(
        self,
        base_url: str = 'http://localhost:11434',
        model: str = 'llama2'
    ):
        """
        Initialize Ollama Language Model

        Args:
            base_url (str): Base URL for Ollama API
            model (str): Name of the Ollama model to use
        """
        self.base_url = base_url
        self.model = model

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = 150,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response using Ollama API

        Args:
            prompt (str): Input prompt
            max_tokens (Optional[int]): Maximum tokens to generate
            temperature (float): Sampling temperature

        Returns:
            str: Generated response
        """
        response = requests.post(
            f"{self.base_url}/api/generate",
            json={
                "model": self.model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": temperature,
                    "num_predict": max_tokens
                }
            }
        )

        response.raise_for_status()
        return response.json().get('response', '').strip()

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize text

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: List of tokens
        """
        # Simple whitespace tokenization; an approximation of the model's tokenizer
        return text.split()

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the text

        Args:
            text (str): Input text to count tokens

        Returns:
            int: Number of tokens
        """
        return len(self.tokenize(text))
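A smoke test for the provider, assuming Ollama is listening on its default port with llama2 already pulled:

```python
from src.llms.ollama_llm import OllamaLanguageModel

llm = OllamaLanguageModel()  # defaults: http://localhost:11434, model 'llama2'
print(llm.generate("Explain RAG in one sentence.", max_tokens=60, temperature=0.2))
print(llm.count_tokens("Explain RAG in one sentence."))  # whitespace approximation: 5
```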
src/llms/openai_llm.py
ADDED
# src/llms/openai_llm.py
from openai import OpenAI
from typing import Optional, List

from .base_llm import BaseLLM

class OpenAILanguageModel(BaseLLM):
    def __init__(
        self,
        api_key: str,
        model: str = 'gpt-3.5-turbo'
    ):
        """
        Initialize OpenAI Language Model

        Args:
            api_key (str): OpenAI API key
            model (str): Name of the OpenAI model to use
        """
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = 150,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response using OpenAI API

        Args:
            prompt (str): Input prompt
            max_tokens (Optional[int]): Maximum tokens to generate
            temperature (float): Sampling temperature

        Returns:
            str: Generated response
        """
        # openai>=1.0 (pinned as 1.12.0) uses the client API instead of
        # the removed openai.ChatCompletion.create
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        return response.choices[0].message.content.strip()

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize text (approximate)

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: List of tokens
        """
        # Note: whitespace splitting is a placeholder; OpenAI models use BPE
        # tokenization, which a library such as tiktoken can reproduce offline.
        return text.split()

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the text

        Args:
            text (str): Input text to count tokens

        Returns:
            int: Number of tokens
        """
        # Approximate token counting
        return len(self.tokenize(text))
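The whitespace tokenizer above only approximates OpenAI's BPE counts. A more faithful sketch uses tiktoken, which is not pinned in requirements.txt and would be an extra dependency:

```python
import tiktoken  # assumed extra dependency: pip install tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
text = "Retrieval Augmented Generation grounds answers in retrieved context."
print(len(enc.encode(text)))  # exact BPE token count
print(len(text.split()))      # the whitespace approximation used above: 8
```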
src/main.py
ADDED
# src/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

class ChatRequest(BaseModel):
    query: str
    context_docs: Optional[List[str]] = None
    llm_provider: str = 'openai'

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        # Select LLM based on provider
        if request.llm_provider == 'openai':
            llm = OpenAILanguageModel(api_key=settings.OPENAI_API_KEY, model=settings.OPENAI_MODEL)
        elif request.llm_provider == 'ollama':
            llm = OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL, model=settings.OLLAMA_MODEL)
        else:
            raise HTTPException(status_code=400, detail="Unsupported LLM provider")

        # Initialize embedding and vector store
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents,
            persist_directory=settings.CHROMA_PATH
        )

        # Create RAG agent
        rag_agent = RAGAgent(
            llm=llm,
            embedding=embedding,
            vector_store=vector_store
        )

        # Process query
        response = rag_agent.generate_response(
            query=request.query,
            context_docs=request.context_docs
        )

        return ChatResponse(
            response=response.response,
            context=response.context_docs
        )

    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Optional: Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
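Exercising both endpoints with requests (already pinned in requirements.txt), assuming the development server from the Readme is running locally:

```python
import requests

print(requests.get("http://localhost:8000/health").json())  # {'status': 'healthy'}

resp = requests.post(
    "http://localhost:8000/chat",
    json={"query": "What is RAG?", "llm_provider": "ollama"},
)
resp.raise_for_status()
body = resp.json()
print(body["response"])
print(body["context"])
```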
src/utils/__init__.py
ADDED
File without changes

src/utils/document_loader.py
ADDED
# src/utils/document_loader.py
import os
from typing import List
import PyPDF2
import docx

def load_document(file_path: str) -> str:
    """
    Load text from various document types

    Args:
        file_path (str): Path to the document file

    Returns:
        str: Extracted text from the document

    Raises:
        ValueError: If file type is not supported
    """
    # Get file extension
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()

    # Load based on file type
    if ext == '.txt':
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

    elif ext == '.pdf':
        return load_pdf(file_path)

    elif ext == '.docx':
        return load_docx(file_path)

    else:
        raise ValueError(f"Unsupported file type: {ext}")

def load_pdf(file_path: str) -> str:
    """
    Extract text from PDF file

    Args:
        file_path (str): Path to PDF file

    Returns:
        str: Extracted text
    """
    text = ""
    with open(file_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        for page in reader.pages:
            text += page.extract_text()
    return text

def load_docx(file_path: str) -> str:
    """
    Extract text from DOCX file

    Args:
        file_path (str): Path to DOCX file

    Returns:
        str: Extracted text
    """
    doc = docx.Document(file_path)
    return '\n'.join([paragraph.text for paragraph in doc.paragraphs])

def load_documents_from_directory(
    directory: str,
    extensions: List[str] = ['.txt', '.pdf', '.docx']
) -> List[str]:
    """
    Load all documents from a directory

    Args:
        directory (str): Path to the directory
        extensions (List[str]): List of file extensions to load

    Returns:
        List[str]: List of document texts
    """
    documents = []
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        if os.path.isfile(file_path) and any(filename.lower().endswith(ext) for ext in extensions):
            try:
                documents.append(load_document(file_path))
            except Exception as e:
                print(f"Error loading {filename}: {e}")

    return documents
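Chaining the loader with the utilities in src/utils/text_splitter.py to prepare a corpus for ingestion; the data/ directory follows the Readme's project layout:

```python
from src.utils.document_loader import load_documents_from_directory
from src.utils.text_splitter import clean_text, split_text

docs = load_documents_from_directory("data")
chunks = [chunk for doc in docs for chunk in split_text(clean_text(doc))]
print(f"{len(docs)} documents -> {len(chunks)} chunks")
```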
src/utils/logger.py
ADDED
# src/utils/logger.py
import logging
import sys
from typing import Optional

def setup_logger(
    name: str = "rag_chatbot",
    log_level: str = "INFO",
    log_file: Optional[str] = None
) -> logging.Logger:
    """
    Set up a comprehensive logger for the application

    Args:
        name (str): Name of the logger
        log_level (str): Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        log_file (Optional[str]): Path to log file (optional)

    Returns:
        logging.Logger: Configured logger instance
    """
    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, log_level.upper()))
    logger.handlers.clear()  # avoid duplicate handlers if setup_logger is called again

    # Create formatters
    console_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
    )

    # Console Handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # File Handler (if log_file is provided)
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger

# Global logger instance
logger = setup_logger()

class AppException(Exception):
    """
    Custom base exception for the application
    """
    def __init__(self, message: str, error_code: Optional[str] = None):
        """
        Initialize custom exception

        Args:
            message (str): Error message
            error_code (Optional[str]): Optional error code
        """
        self.message = message
        self.error_code = error_code
        super().__init__(self.message)

        # Log the exception
        logger.error(f"AppException: {message}")

class ConfigurationError(AppException):
    """Exception raised for configuration-related errors"""
    pass

class LLMProviderError(AppException):
    """Exception raised for LLM provider-related errors"""
    pass

class EmbeddingError(AppException):
    """Exception raised for embedding-related errors"""
    pass

class VectorStoreError(AppException):
    """Exception raised for vector store-related errors"""
    pass
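Typical use of the logger and the exception hierarchy; the logger name and file path are illustrative:

```python
from src.utils.logger import setup_logger, LLMProviderError

log = setup_logger(name="ingest", log_level="DEBUG", log_file="ingest.log")
log.info("starting ingestion")

try:
    raise LLMProviderError("Ollama is unreachable", error_code="LLM_DOWN")
except LLMProviderError as exc:
    # AppException already logged itself on construction; add handling context here
    log.warning(f"retrying after: {exc.message} ({exc.error_code})")
```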
src/utils/text_splitter.py
ADDED
# src/utils/text_splitter.py
from typing import List

def split_text(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> List[str]:
    """
    Split a long text into smaller chunks

    Args:
        text (str): Input text to split
        chunk_size (int): Maximum size of each text chunk
        overlap (int): Number of characters to overlap between chunks

    Returns:
        List[str]: List of text chunks
    """
    chunks = []
    start = 0

    while start < len(text):
        # Extract chunk
        chunk = text[start:start + chunk_size]
        chunks.append(chunk)

        # Move start position with overlap
        start += chunk_size - overlap

    return chunks

def clean_text(text: str) -> str:
    """
    Clean and preprocess text

    Args:
        text (str): Input text to clean

    Returns:
        str: Cleaned text
    """
    # Remove extra whitespaces
    text = ' '.join(text.split())

    # Add more cleaning steps as needed
    # For example:
    # - Remove special characters
    # - Convert to lowercase
    # - Remove HTML tags

    return text
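The overlap arithmetic on a concrete input: each window starts chunk_size - overlap characters after the previous one, so with chunk_size=50 and overlap=10 a 120-character text yields windows at offsets 0, 40, and 80:

```python
from src.utils.text_splitter import split_text

chunks = split_text("x" * 120, chunk_size=50, overlap=10)
print([len(c) for c in chunks])  # [50, 50, 40]
```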
src/vectorstores/__init__.py
ADDED
File without changes

src/vectorstores/base_vectorstore.py
ADDED
# src/vectorstores/base_vectorstore.py
from abc import ABC, abstractmethod
from typing import List

class BaseVectorStore(ABC):
    @abstractmethod
    def add_documents(
        self,
        documents: List[str],
        embeddings: List[List[float]]
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (List[List[float]]): Corresponding embeddings
        """
        pass

    @abstractmethod
    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3
    ) -> List[str]:
        """
        Perform similarity search

        Args:
            query_embedding (List[float]): Embedding of the query
            top_k (int): Number of top similar documents to retrieve

        Returns:
            List[str]: List of most similar documents
        """
        pass
src/vectorstores/chroma_vectorstore.py
ADDED
# src/vectorstores/chroma_vectorstore.py
import chromadb
from typing import List, Callable, Optional

from .base_vectorstore import BaseVectorStore

class ChromaVectorStore(BaseVectorStore):
    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db'
    ):
        """
        Initialize Chroma Vector Store

        Args:
            embedding_function (Callable): Function to generate embeddings
            persist_directory (str): Directory to persist the vector store
        """
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.client.get_or_create_collection(name="documents")
        self.embedding_function = embedding_function

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (List[List[float]], optional): Pre-computed embeddings
        """
        if not embeddings:
            embeddings = self.embedding_function(documents)

        # Generate unique IDs, offset by the current collection size so repeated calls don't collide
        ids = [f"doc_{self.collection.count() + i}" for i in range(len(documents))]

        self.collection.add(
            documents=documents,
            embeddings=embeddings,
            ids=ids
        )

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3
    ) -> List[str]:
        """
        Perform similarity search

        Args:
            query_embedding (List[float]): Embedding of the query
            top_k (int): Number of top similar documents to retrieve

        Returns:
            List[str]: List of most similar documents
        """
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )

        return results.get('documents', [[]])[0]