Commit e87abff
Parent(s): 926ee1d
Added support for multiple LLMs
- Modified `.vscode/settings.json` for updated linting and formatting settings.
- Added `Chatbot-design.png` for design reference.
- Added `chat_history.db` for storing chat history.
- Added `chroma_db/chroma.sqlite3` for Chroma vector store persistence.
- Updated `src/agents/rag_agent.py` to support multiple LLMs.
- Updated `src/embeddings/__init__.py` for embedding initialization.
- Added compiled Python files to `.gitignore`.
- .vscode/settings.json +2 -1
- Chatbot-design.png +0 -0
- chat_history.db +0 -0
- config/__pycache__/__init__.cpython-312.pyc +0 -0
- config/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/main.cpython-312.pyc +0 -0
- src/agents/__pycache__/__init__.cpython-312.pyc +0 -0
- src/agents/__pycache__/rag_agent.cpython-312.pyc +0 -0
- src/agents/rag_agent.py +3 -3
- src/embeddings/__init__.py +4 -0
- src/embeddings/__pycache__/__init__.cpython-312.pyc +0 -0
- src/embeddings/__pycache__/base_embedding.cpython-312.pyc +0 -0
- src/embeddings/__pycache__/huggingface_embedding.cpython-312.pyc +0 -0
- src/llms/__pycache__/__init__.cpython-312.pyc +0 -0
- src/llms/__pycache__/base_llm.cpython-312.pyc +0 -0
- src/llms/__pycache__/bert_llm.cpython-312.pyc +0 -0
- src/llms/__pycache__/falcon_llm.cpython-312.pyc +0 -0
- src/llms/__pycache__/llama_llm.cpython-312.pyc +0 -0
- src/llms/__pycache__/ollama_llm.cpython-312.pyc +0 -0
- src/llms/__pycache__/openai_llm.cpython-312.pyc +0 -0
- src/llms/bert_llm.py +44 -0
- src/llms/falcon_llm.py +39 -0
- src/llms/llama_llm.py +39 -0
- src/llms/openai_llm.py +4 -10
- src/main.py +342 -27
- src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- src/utils/__pycache__/conversation_summarizer.cpython-312.pyc +0 -0
- src/utils/__pycache__/document_processor.cpython-312.pyc +0 -0
- src/utils/__pycache__/logger.cpython-312.pyc +0 -0
- src/utils/__pycache__/text_splitter.cpython-312.pyc +0 -0
- src/utils/conversation_summarizer.py +128 -0
- src/utils/document_processor.py +262 -0
- src/vctorstores/__init__.py +0 -0
- src/vectorstores/__init__.py +3 -0
- src/vectorstores/__pycache__/__init__.cpython-312.pyc +0 -0
- src/vectorstores/__pycache__/base_vectorstore.cpython-312.pyc +0 -0
- src/vectorstores/__pycache__/chroma_vectorstore.cpython-312.pyc +0 -0
- src/{vctorstores → vectorstores}/base_vectorstore.py +0 -0
- src/{vctorstores → vectorstores}/chroma_vectorstore.py +0 -0
.vscode/settings.json
CHANGED
@@ -8,5 +8,6 @@
     "tests"
   ],
   "python.testing.unittestEnabled": false,
-  "python.testing.pytestEnabled": true
+  "python.testing.pytestEnabled": true,
+  "git.ignoreLimitWarning": true
 }

Chatbot-design.png
ADDED
Binary image file.

chat_history.db
ADDED
Binary file (12.3 kB).

config/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (208 Bytes).

config/__pycache__/config.cpython-312.pyc
ADDED
Binary file (1.23 kB).

src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (205 Bytes).

src/__pycache__/main.cpython-312.pyc
ADDED
Binary file (20.4 kB).

src/agents/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (212 Bytes).

src/agents/__pycache__/rag_agent.cpython-312.pyc
ADDED
Binary file (3.76 kB).

src/agents/rag_agent.py
CHANGED
@@ -3,9 +3,9 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 from ..llms.base_llm import BaseLLM
-from
-from
-from
+from src.embeddings.base_embedding import BaseEmbedding
+from src.vectorstores.base_vectorstore import BaseVectorStore
+from src.utils.text_splitter import split_text
 
 @dataclass
 class RAGResponse:

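The three reconstructed imports are what RAGAgent now depends on: an embedding backend, a vector store, and the shared text splitter. A minimal wiring sketch follows, assuming the RAGAgent constructor keywords (llm, embedding, vector_store) that src/main.py uses later in this commit; the embedding model name and API key below are placeholders, not values from the diff.

    from src.agents.rag_agent import RAGAgent
    from src.llms.openai_llm import OpenAILanguageModel
    from src.embeddings.huggingface_embedding import HuggingFaceEmbedding
    from src.vectorstores.chroma_vectorstore import ChromaVectorStore

    # Placeholder values; the app itself reads these from config.config.settings.
    embedding = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = ChromaVectorStore(
        embedding_function=embedding.embed_documents,
        persist_directory="./chroma_db",
    )
    agent = RAGAgent(
        llm=OpenAILanguageModel(api_key="sk-..."),  # placeholder key
        embedding=embedding,
        vector_store=vector_store,
    )
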
src/embeddings/__init__.py
CHANGED
@@ -0,0 +1,4 @@
# src/embeddings/__init__.py
from .base_embedding import BaseEmbedding

__all__ = ['BaseEmbedding']

src/embeddings/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (299 Bytes).

src/embeddings/__pycache__/base_embedding.cpython-312.pyc
ADDED
Binary file (1.34 kB).

src/embeddings/__pycache__/huggingface_embedding.cpython-312.pyc
ADDED
Binary file (1.98 kB).

src/llms/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (210 Bytes).

src/llms/__pycache__/base_llm.cpython-312.pyc
ADDED
Binary file (1.94 kB).

src/llms/__pycache__/bert_llm.cpython-312.pyc
ADDED
Binary file (2.35 kB).

src/llms/__pycache__/falcon_llm.cpython-312.pyc
ADDED
Binary file (2.08 kB).

src/llms/__pycache__/llama_llm.cpython-312.pyc
ADDED
Binary file (2.07 kB).

src/llms/__pycache__/ollama_llm.cpython-312.pyc
ADDED
Binary file (2.9 kB).

src/llms/__pycache__/openai_llm.cpython-312.pyc
ADDED
Binary file (2.94 kB).

src/llms/bert_llm.py
ADDED
@@ -0,0 +1,44 @@
# src/llms/bert_llm.py
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from typing import Optional, List
from .base_llm import BaseLLM

class BERTLanguageModel(BaseLLM):
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        max_length: int = 512
    ):
        """Initialize BERT model"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer
        )
        self.max_length = max_length

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate text using BERT"""
        output = self.generator(
            prompt,
            max_length=max_tokens or self.max_length,
            temperature=temperature,
            **kwargs
        )
        return output[0]['generated_text']

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using BERT tokenizer"""
        return self.tokenizer.tokenize(text)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        return len(self.tokenizer.encode(text))

src/llms/falcon_llm.py
ADDED
@@ -0,0 +1,39 @@
# src/llms/falcon_llm.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Optional, List
from .base_llm import BaseLLM

class FalconLanguageModel(BaseLLM):
    def __init__(
        self,
        model_name: str = "tiiuae/falcon-7b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Initialize Falcon model"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch.float16
        )
        self.device = device

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate text using Falcon"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        outputs = self.model.generate(
            **inputs,
            max_length=max_tokens if max_tokens else 100,
            temperature=temperature,
            **kwargs
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

src/llms/llama_llm.py
ADDED
@@ -0,0 +1,39 @@
# src/llms/llama_llm.py
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch
from typing import Optional, List
from .base_llm import BaseLLM

class LlamaLanguageModel(BaseLLM):
    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Initialize Llama model"""
        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
        self.model = LlamaForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch.float16
        )
        self.device = device

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate text using Llama"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        outputs = self.model.generate(
            **inputs,
            max_length=max_tokens if max_tokens else 100,
            temperature=temperature,
            **kwargs
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

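All three wrappers follow the same pattern: load a Hugging Face checkpoint in __init__ and expose generate(prompt, max_tokens, temperature). A usage sketch for the Falcon wrapper follows. Caveats that are not part of the diff: tiiuae/falcon-7b (and the gated meta-llama/Llama-2-7b) are multi-gigabyte downloads that realistically need a GPU; bert-base-uncased is an encoder-only model, so the text-generation pipeline in BERTLanguageModel generally needs a generation-capable checkpoint instead; and direct instantiation assumes BaseLLM does not force falcon_llm.py/llama_llm.py to override tokenize and count_tokens.

    # Sketch only: assumes a CUDA GPU and that BaseLLM allows generate()-only subclasses.
    from src.llms.falcon_llm import FalconLanguageModel

    llm = FalconLanguageModel()  # defaults to tiiuae/falcon-7b, loaded in float16
    print(llm.generate(
        "Summarize retrieval-augmented generation in one sentence.",
        max_tokens=60,
        temperature=0.7,
    ))
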
src/llms/openai_llm.py
CHANGED
@@ -1,15 +1,12 @@
 # src/llms/openai_llm.py
 import openai
 from typing import Optional, List
+from openai import OpenAI  # Import the new client
 
 from .base_llm import BaseLLM
 
 class OpenAILanguageModel(BaseLLM):
-    def __init__(
-        self,
-        api_key: str,
-        model: str = 'gpt-3.5-turbo'
-    ):
+    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
         """
         Initialize OpenAI Language Model
 
@@ -17,7 +14,7 @@ class OpenAILanguageModel(BaseLLM):
             api_key (str): OpenAI API key
             model (str): Name of the OpenAI model to use
         """
-
+        self.client = OpenAI(api_key=api_key)  # Use the new client
         self.model = model
 
     def generate(
@@ -38,7 +35,7 @@ class OpenAILanguageModel(BaseLLM):
         Returns:
             str: Generated response
         """
-        response =
+        response = self.client.chat.completions.create(
             model=self.model,
             messages=[{"role": "user", "content": prompt}],
             max_tokens=max_tokens,
@@ -58,8 +55,6 @@ class OpenAILanguageModel(BaseLLM):
         Returns:
             List[str]: List of tokens
         """
-        # Note: This is a placeholder. OpenAI doesn't provide a direct
-        # tokenization method without making an API call.
         return text.split()
 
     def count_tokens(self, text: str) -> int:
@@ -72,5 +67,4 @@ class OpenAILanguageModel(BaseLLM):
         Returns:
             int: Number of tokens
         """
-        # Approximate token counting
         return len(self.tokenize(text))

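The change above migrates the wrapper from the legacy module-level openai calls to the v1-style client (OpenAI(...).chat.completions.create(...)), which is the current openai Python SDK interface. A short usage sketch, with a placeholder API key; note that tokenize() is still a whitespace split, so count_tokens() remains an approximation.

    from src.llms.openai_llm import OpenAILanguageModel

    llm = OpenAILanguageModel(api_key="sk-...", model="gpt-3.5-turbo")  # placeholder key
    reply = llm.generate("In one sentence, what does this chatbot do?", max_tokens=64)
    print(reply)
    print(llm.count_tokens(reply))  # whitespace-based estimate, per tokenize()
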
src/main.py
CHANGED
@@ -1,66 +1,381 @@

Old version (66 lines; lines removed by this commit are prefixed with -):

 # src/main.py
-from fastapi import FastAPI, Depends,
 from pydantic import BaseModel
-from typing import List, Optional
 
 from .agents.rag_agent import RAGAgent
 from .llms.openai_llm import OpenAILanguageModel
 from .llms.ollama_llm import OllamaLanguageModel
 from .embeddings.huggingface_embedding import HuggingFaceEmbedding
 from .vectorstores.chroma_vectorstore import ChromaVectorStore
 from config.config import settings
 
 app = FastAPI(title="RAG Chatbot API")
 
 class ChatRequest(BaseModel):
     query: str
-    context_docs: Optional[List[str]] = None
     llm_provider: str = 'openai'
 
 class ChatResponse(BaseModel):
     response: str
     context: Optional[List[str]] = None
 
-
-
     try:
-        # Select LLM based on provider
-        if request.llm_provider == 'openai':
-            llm = OpenAILanguageModel(api_key=settings.OPENAI_API_KEY)
-        elif request.llm_provider == 'ollama':
-            llm = OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL)
-        else:
-            raise HTTPException(status_code=400, detail="Unsupported LLM provider")
-
-        # Initialize embedding and vector store
         embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
         vector_store = ChromaVectorStore(
-            embedding_function=embedding.embed_documents,
             persist_directory=settings.CHROMA_PATH
         )
 
-        # Create RAG agent
         rag_agent = RAGAgent(
-            llm=llm,
-            embedding=
             vector_store=vector_store
         )
 
-
-
-
-
         )
 
         return ChatResponse(
-            response=response.response,
-            context=response.context_docs
         )
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-# Optional: Health check endpoint
 @app.get("/health")
 async def health_check():
-

New version (381 lines):

# src/main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator, Dict
import asyncio
import json
import uuid
from datetime import datetime
import aiosqlite
from pathlib import Path
import shutil
import os

# Import custom modules
from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .llms.bert_llm import BERTLanguageModel
from .llms.falcon_llm import FalconLanguageModel
from .llms.llama_llm import LlamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from .utils.document_processor import DocumentProcessor
from .utils.conversation_summarizer import ConversationSummarizer
from .utils.logger import logger
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

# Initialize core components
doc_processor = DocumentProcessor(
    chunk_size=1000,
    chunk_overlap=200,
    max_file_size=10 * 1024 * 1024
)
summarizer = ConversationSummarizer()

# Pydantic models
class ChatRequest(BaseModel):
    query: str
    llm_provider: str = 'openai'
    max_context_docs: int = 3
    temperature: float = 0.7
    stream: bool = False
    conversation_id: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None
    sources: Optional[List[Dict[str, str]]] = None
    conversation_id: str
    timestamp: datetime
    relevant_doc_scores: Optional[List[float]] = None

class DocumentResponse(BaseModel):
    message: str
    document_id: str
    status: str
    document_info: Optional[dict] = None

class BatchUploadResponse(BaseModel):
    message: str
    processed_files: List[DocumentResponse]
    failed_files: List[dict]

class SummarizeRequest(BaseModel):
    conversation_id: str
    include_metadata: bool = True

class SummaryResponse(BaseModel):
    summary: str
    key_insights: Dict
    metadata: Optional[Dict] = None

class FeedbackRequest(BaseModel):
    rating: int
    feedback: Optional[str] = None

# Database initialization
async def init_db():
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute('''
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                query TEXT,
                response TEXT,
                context TEXT,
                sources TEXT,
                llm_provider TEXT,
                feedback TEXT,
                rating INTEGER
            )
        ''')
        await db.commit()

# Utility functions
def get_llm_instance(provider: str):
    """Get LLM instance based on provider"""
    llm_map = {
        'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
        'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
        'bert': lambda: BERTLanguageModel(),
        'falcon': lambda: FalconLanguageModel(),
        'llama': lambda: LlamaLanguageModel(),
    }

    if provider not in llm_map:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return llm_map[provider]()

async def get_vector_store():
    """Initialize and return vector store with embedding model."""
    try:
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents,
            persist_directory=settings.CHROMA_PATH
        )
        return vector_store, embedding
    except Exception as e:
        logger.error(f"Error initializing vector store: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to initialize vector store")

async def process_and_store_document(
    file_path: Path,
    vector_store: ChromaVectorStore,
    document_id: str
):
    """Process document and store in vector database."""
    try:
        processed_doc = await doc_processor.process_document(file_path)

        vector_store.add_documents(
            documents=processed_doc['chunks'],
            metadatas=[{
                'document_id': document_id,
                'chunk_id': i,
                'source': str(file_path.name),
                'metadata': processed_doc['metadata']
            } for i in range(len(processed_doc['chunks']))],
            ids=[f"{document_id}_chunk_{i}" for i in range(len(processed_doc['chunks']))]
        )

        return processed_doc
    finally:
        if file_path.exists():
            file_path.unlink()

async def store_chat_history(
    conversation_id: str,
    query: str,
    response: str,
    context: List[str],
    sources: List[Dict],
    llm_provider: str
):
    """Store chat history in database"""
    async with aiosqlite.connect('chat_history.db') as db:
        await db.execute(
            '''INSERT INTO chat_history
            (conversation_id, query, response, context, sources, llm_provider)
            VALUES (?, ?, ?, ?, ?, ?)''',
            (conversation_id, query, response, json.dumps(context),
             json.dumps(sources), llm_provider)
        )
        await db.commit()

# Endpoints
@app.post("/documents/upload", response_model=BatchUploadResponse)
async def upload_documents(
    files: List[UploadFile] = File(...),
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Upload and process multiple documents"""
    try:
        vector_store, _ = await get_vector_store()
        upload_dir = Path("temp_uploads")
        upload_dir.mkdir(exist_ok=True)

        processed_files = []
        failed_files = []

        for file in files:
            try:
                document_id = str(uuid.uuid4())

                if not any(file.filename.lower().endswith(ext)
                           for ext in doc_processor.supported_formats):
                    failed_files.append({
                        "filename": file.filename,
                        "error": "Unsupported file format"
                    })
                    continue

                temp_path = upload_dir / f"{document_id}_{file.filename}"
                with open(temp_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

                background_tasks.add_task(
                    process_and_store_document,
                    temp_path,
                    vector_store,
                    document_id
                )

                processed_files.append(
                    DocumentResponse(
                        message="Document queued for processing",
                        document_id=document_id,
                        status="processing",
                        document_info={
                            "original_filename": file.filename,
                            "size": os.path.getsize(temp_path),
                            "content_type": file.content_type
                        }
                    )
                )

            except Exception as e:
                logger.error(f"Error processing file {file.filename}: {str(e)}")
                failed_files.append({
                    "filename": file.filename,
                    "error": str(e)
                })

        return BatchUploadResponse(
            message=f"Processed {len(processed_files)} documents with {len(failed_files)} failures",
            processed_files=processed_files,
            failed_files=failed_files
        )

    except Exception as e:
        logger.error(f"Error in document upload: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        if upload_dir.exists() and not any(upload_dir.iterdir()):
            upload_dir.rmdir()

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    background_tasks: BackgroundTasks
):
    """Chat endpoint with RAG support"""
    try:
        vector_store, embedding_model = await get_vector_store()
        llm = get_llm_instance(request.llm_provider)

        rag_agent = RAGAgent(
            llm=llm,
            embedding=embedding_model,
            vector_store=vector_store
        )

        if request.stream:
            return StreamingResponse(
                rag_agent.generate_streaming_response(request.query),
                media_type="text/event-stream"
            )

        response = await rag_agent.generate_response(
            query=request.query,
            temperature=request.temperature
        )

        conversation_id = request.conversation_id or str(uuid.uuid4())

        background_tasks.add_task(
            store_chat_history,
            conversation_id,
            request.query,
            response.response,
            response.context_docs,
            response.sources,
            request.llm_provider
        )

        return ChatResponse(
            response=response.response,
            context=response.context_docs,
            sources=response.sources,
            conversation_id=conversation_id,
            timestamp=datetime.now(),
            relevant_doc_scores=response.scores if hasattr(response, 'scores') else None
        )

    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/chat/history/{conversation_id}")
async def get_conversation_history(conversation_id: str):
    """Get complete conversation history"""
    async with aiosqlite.connect('chat_history.db') as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
            (conversation_id,)
        ) as cursor:
            history = await cursor.fetchall()

    if not history:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {
        "conversation_id": conversation_id,
        "messages": [dict(row) for row in history]
    }

@app.post("/chat/summarize", response_model=SummaryResponse)
async def summarize_conversation(request: SummarizeRequest):
    """Generate a summary of a conversation"""
    try:
        async with aiosqlite.connect('chat_history.db') as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                'SELECT * FROM chat_history WHERE conversation_id = ? ORDER BY timestamp',
                (request.conversation_id,)
            ) as cursor:
                history = await cursor.fetchall()

        if not history:
            raise HTTPException(status_code=404, detail="Conversation not found")

        messages = [{
            'role': 'user' if msg['query'] else 'assistant',
            'content': msg['query'] or msg['response'],
            'timestamp': msg['timestamp'],
            'sources': json.loads(msg['sources']) if msg['sources'] else None
        } for msg in history]

        summary = await summarizer.summarize_conversation(
            messages,
            include_metadata=request.include_metadata
        )

        return SummaryResponse(**summary)

    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/chat/feedback/{conversation_id}")
async def submit_feedback(
    conversation_id: str,
    feedback_request: FeedbackRequest
):
    """Submit feedback for a conversation"""
    try:
        async with aiosqlite.connect('chat_history.db') as db:
            await db.execute(
                '''UPDATE chat_history
                SET feedback = ?, rating = ?
                WHERE conversation_id = ?''',
                (feedback_request.feedback, feedback_request.rating, conversation_id)
            )
            await db.commit()

        return {"status": "Feedback submitted successfully"}

    except Exception as e:
        logger.error(f"Error submitting feedback: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}

# Startup event
@app.on_event("startup")
async def startup_event():
    await init_db()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

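With the new endpoints in place, a client can upload documents and then chat against them. A minimal client sketch follows, using the third-party requests package (not a dependency of this repo); the base URL matches the uvicorn.run() call above, the JSON fields mirror ChatRequest, and notes.txt is a hypothetical local file.

    import requests

    BASE = "http://localhost:8000"

    # Queue a document for background processing and indexing.
    with open("notes.txt", "rb") as fh:
        upload = requests.post(f"{BASE}/documents/upload", files={"files": fh})
    print(upload.json()["message"])

    # Ask a question; field names mirror the ChatRequest model.
    chat = requests.post(f"{BASE}/chat", json={
        "query": "What do my notes say about project deadlines?",
        "llm_provider": "openai",
        "temperature": 0.7,
    })
    body = chat.json()
    print(body["response"])
    print(body["conversation_id"])  # reuse this id to continue the conversation
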
src/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (211 Bytes).

src/utils/__pycache__/conversation_summarizer.cpython-312.pyc
ADDED
Binary file (5.96 kB).

src/utils/__pycache__/document_processor.cpython-312.pyc
ADDED
Binary file (12.9 kB).

src/utils/__pycache__/logger.cpython-312.pyc
ADDED
Binary file (3.69 kB).

src/utils/__pycache__/text_splitter.cpython-312.pyc
ADDED
Binary file (1.43 kB).

src/utils/conversation_summarizer.py
ADDED
@@ -0,0 +1,128 @@
# src/utils/conversation_summarizer.py
from typing import List, Dict
from transformers import pipeline
import numpy as np
from datetime import datetime

class ConversationSummarizer:
    def __init__(
        self,
        model_name: str = "facebook/bart-large-cnn",
        max_length: int = 130,
        min_length: int = 30
    ):
        """Initialize the summarizer"""
        self.summarizer = pipeline(
            "summarization",
            model=model_name,
            device=-1  # CPU
        )
        self.max_length = max_length
        self.min_length = min_length

    async def summarize_conversation(
        self,
        messages: List[Dict],
        include_metadata: bool = True
    ) -> Dict:
        """
        Summarize a conversation and provide key insights
        """
        # Format conversation for summarization
        formatted_convo = self._format_conversation(messages)

        # Generate summary
        summary = self.summarizer(
            formatted_convo,
            max_length=self.max_length,
            min_length=self.min_length,
            do_sample=False
        )[0]['summary_text']

        # Extract key insights
        insights = self._extract_insights(messages)

        # Generate metadata if requested
        metadata = self._generate_metadata(messages) if include_metadata else {}

        return {
            'summary': summary,
            'key_insights': insights,
            'metadata': metadata
        }

    def _format_conversation(self, messages: List[Dict]) -> str:
        """Format conversation for summarization"""
        formatted = []
        for msg in messages:
            role = msg.get('role', 'unknown')
            content = msg.get('content', '')
            formatted.append(f"{role}: {content}")

        return "\n".join(formatted)

    def _extract_insights(self, messages: List[Dict]) -> Dict:
        """Extract key insights from conversation"""
        # Count message types
        message_counts = {
            'user': len([m for m in messages if m.get('role') == 'user']),
            'assistant': len([m for m in messages if m.get('role') == 'assistant'])
        }

        # Calculate average message length
        avg_length = np.mean([len(m.get('content', '')) for m in messages])

        # Extract main topics (simplified)
        topics = self._extract_topics(messages)

        return {
            'message_distribution': message_counts,
            'average_message_length': int(avg_length),
            'main_topics': topics,
            'total_messages': len(messages)
        }

    def _extract_topics(self, messages: List[Dict]) -> List[str]:
        """Extract main topics from conversation"""
        # Combine all messages
        full_text = " ".join([m.get('content', '') for m in messages])

        # Use the summarizer to extract main points
        topics = self.summarizer(
            full_text,
            max_length=50,
            min_length=10,
            do_sample=False
        )[0]['summary_text'].split('. ')

        return topics

    def _generate_metadata(self, messages: List[Dict]) -> Dict:
        """Generate conversation metadata"""
        if not messages:
            return {}

        return {
            'start_time': messages[0].get('timestamp', None),
            'end_time': messages[-1].get('timestamp', None),
            'duration_minutes': self._calculate_duration(messages),
            'sources_used': self._extract_sources(messages)
        }

    def _calculate_duration(self, messages: List[Dict]) -> float:
        """Calculate conversation duration in minutes"""
        try:
            start_time = datetime.fromisoformat(messages[0].get('timestamp', ''))
            end_time = datetime.fromisoformat(messages[-1].get('timestamp', ''))
            return (end_time - start_time).total_seconds() / 60
        except:
            return 0

    def _extract_sources(self, messages: List[Dict]) -> List[str]:
        """Extract unique sources used in conversation"""
        sources = set()
        for message in messages:
            if message.get('sources'):
                for source in message['sources']:
                    sources.add(source.get('filename', ''))
        return list(sources)

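summarize_conversation() is async and expects the same role/content/timestamp/sources message dicts that the /chat/summarize endpoint builds from chat_history rows. A small driver sketch follows; the messages are made up for illustration, and facebook/bart-large-cnn is downloaded on first use.

    import asyncio
    from src.utils.conversation_summarizer import ConversationSummarizer

    messages = [
        {"role": "user", "content": "How do I upload PDFs to the chatbot?",
         "timestamp": "2024-01-01T10:00:00"},
        {"role": "assistant",
         "content": "POST them to /documents/upload; they are chunked and stored in Chroma.",
         "timestamp": "2024-01-01T10:00:05"},
    ]

    result = asyncio.run(ConversationSummarizer().summarize_conversation(messages))
    print(result["summary"])
    print(result["key_insights"]["total_messages"])
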
src/utils/document_processor.py
ADDED
@@ -0,0 +1,262 @@
# src/utils/document_processor.py
from typing import List, Dict, Optional, Union
import PyPDF2
import docx
import pandas as pd
import json
from pathlib import Path
import hashlib
import magic  # python-magic library for file type detection
from bs4 import BeautifulSoup
import requests
import csv
from datetime import datetime
import threading
from queue import Queue
import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter

class DocumentProcessor:
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        max_file_size: int = 10 * 1024 * 1024,  # 10MB
        supported_formats: Optional[List[str]] = None
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.max_file_size = max_file_size
        self.supported_formats = supported_formats or [
            '.txt', '.pdf', '.docx', '.csv', '.json',
            '.html', '.md', '.xml', '.rtf'
        ]
        self.processing_queue = Queue()
        self.processed_docs = {}
        self._initialize_text_splitter()

    def _initialize_text_splitter(self):
        """Initialize the text splitter with custom settings"""
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

    async def process_document(
        self,
        file_path: Union[str, Path],
        metadata: Optional[Dict] = None
    ) -> Dict:
        """
        Process a document with metadata and content extraction
        """
        file_path = Path(file_path)

        # Basic validation
        if not self._validate_file(file_path):
            raise ValueError(f"Invalid file: {file_path}")

        # Extract content based on file type
        content = self._extract_content(file_path)

        # Generate document metadata
        doc_metadata = self._generate_metadata(file_path, content, metadata)

        # Split content into chunks
        chunks = self.text_splitter.split_text(content)

        # Calculate embeddings chunk hashes
        chunk_hashes = [self._calculate_hash(chunk) for chunk in chunks]

        return {
            'content': content,
            'chunks': chunks,
            'chunk_hashes': chunk_hashes,
            'metadata': doc_metadata,
            'statistics': self._generate_statistics(content, chunks)
        }

    def _validate_file(self, file_path: Path) -> bool:
        """
        Validate file type, size, and content
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        if file_path.suffix.lower() not in self.supported_formats:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        if file_path.stat().st_size > self.max_file_size:
            raise ValueError(f"File too large: {file_path}")

        # Check if file is not empty
        if file_path.stat().st_size == 0:
            raise ValueError(f"Empty file: {file_path}")

        return True

    def _extract_content(self, file_path: Path) -> str:
        """
        Extract content from different file formats
        """
        suffix = file_path.suffix.lower()

        try:
            if suffix == '.pdf':
                return self._extract_pdf(file_path)
            elif suffix == '.docx':
                return self._extract_docx(file_path)
            elif suffix == '.csv':
                return self._extract_csv(file_path)
            elif suffix == '.json':
                return self._extract_json(file_path)
            elif suffix == '.html':
                return self._extract_html(file_path)
            elif suffix == '.txt':
                return file_path.read_text(encoding='utf-8')
            else:
                raise ValueError(f"Unsupported format: {suffix}")
        except Exception as e:
            raise Exception(f"Error extracting content from {file_path}: {str(e)}")

    def _extract_pdf(self, file_path: Path) -> str:
        """Extract text from PDF with advanced features"""
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            metadata = reader.metadata

            for page in reader.pages:
                text += page.extract_text() + "\n\n"

                # Extract images if available
                if '/XObject' in page['/Resources']:
                    for obj in page['/Resources']['/XObject'].get_object():
                        if page['/Resources']['/XObject'][obj]['/Subtype'] == '/Image':
                            # Process images if needed
                            pass

        return text.strip()

    def _extract_docx(self, file_path: Path) -> str:
        """Extract text from DOCX with formatting"""
        doc = docx.Document(file_path)
        full_text = []

        for para in doc.paragraphs:
            full_text.append(para.text)

        # Extract tables if present
        for table in doc.tables:
            for row in table.rows:
                row_text = [cell.text for cell in row.cells]
                full_text.append(" | ".join(row_text))

        return "\n\n".join(full_text)

    def _extract_csv(self, file_path: Path) -> str:
        """Convert CSV to structured text"""
        df = pd.read_csv(file_path)
        return df.to_string()

    def _extract_json(self, file_path: Path) -> str:
        """Convert JSON to readable text"""
        with open(file_path) as f:
            data = json.load(f)
        return json.dumps(data, indent=2)

    def _extract_html(self, file_path: Path) -> str:
        """Extract text from HTML with structure preservation"""
        with open(file_path) as f:
            soup = BeautifulSoup(f, 'html.parser')

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.decompose()

        text = soup.get_text(separator='\n')
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return "\n\n".join(lines)

    def _generate_metadata(
        self,
        file_path: Path,
        content: str,
        additional_metadata: Optional[Dict] = None
    ) -> Dict:
        """Generate comprehensive metadata"""
        file_stat = file_path.stat()

        metadata = {
            'filename': file_path.name,
            'file_type': file_path.suffix,
            'file_size': file_stat.st_size,
            'created_at': datetime.fromtimestamp(file_stat.st_ctime),
            'modified_at': datetime.fromtimestamp(file_stat.st_mtime),
            'content_hash': self._calculate_hash(content),
            'mime_type': magic.from_file(str(file_path), mime=True),
            'word_count': len(content.split()),
            'character_count': len(content),
            'processing_timestamp': datetime.now().isoformat()
        }

        if additional_metadata:
            metadata.update(additional_metadata)

        return metadata

    def _generate_statistics(self, content: str, chunks: List[str]) -> Dict:
        """Generate document statistics"""
        return {
            'total_chunks': len(chunks),
            'average_chunk_size': sum(len(chunk) for chunk in chunks) / len(chunks),
            'token_estimate': len(content.split()),
            'unique_words': len(set(content.lower().split())),
            'sentences': len([s for s in content.split('.') if s.strip()]),
        }

    def _calculate_hash(self, text: str) -> str:
        """Calculate SHA-256 hash of text"""
        return hashlib.sha256(text.encode()).hexdigest()

    async def batch_process(
        self,
        file_paths: List[Union[str, Path]],
        parallel: bool = True
    ) -> Dict[str, Dict]:
        """
        Process multiple documents in parallel
        """
        results = {}

        if parallel:
            threads = []
            for file_path in file_paths:
                thread = threading.Thread(
                    target=self._process_and_store,
                    args=(file_path, results)
                )
                threads.append(thread)
                thread.start()

            for thread in threads:
                thread.join()
        else:
            for file_path in file_paths:
                await self._process_and_store(file_path, results)

        return results

    async def _process_and_store(
        self,
        file_path: Union[str, Path],
        results: Dict
    ):
        """Process a single document and store results"""
        try:
            result = await self.process_document(file_path)
            results[str(file_path)] = result
        except Exception as e:
            results[str(file_path)] = {'error': str(e)}

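process_document() is async and returns the extracted text plus chunks, per-chunk hashes, metadata, and statistics. A driver sketch follows; note that importing the module requires python-magic (with the libmagic system library), langchain, and tiktoken to be installed, and report.pdf is a hypothetical local file.

    import asyncio
    from src.utils.document_processor import DocumentProcessor

    async def main():
        processor = DocumentProcessor(chunk_size=1000, chunk_overlap=200)
        doc = await processor.process_document("report.pdf")
        print(doc["metadata"]["filename"], doc["metadata"]["word_count"], "words")
        print("chunks:", doc["statistics"]["total_chunks"])
        print(doc["chunks"][0][:200])  # first 200 characters of the first chunk

    asyncio.run(main())
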
src/vctorstores/__init__.py
DELETED
File without changes

src/vectorstores/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base_vectorstore import BaseVectorStore

__all__ = ['BaseVectorStore']

src/vectorstores/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (305 Bytes).

src/vectorstores/__pycache__/base_vectorstore.cpython-312.pyc
ADDED
Binary file (1.62 kB).

src/vectorstores/__pycache__/chroma_vectorstore.cpython-312.pyc
ADDED
Binary file (2.97 kB).

src/{vctorstores → vectorstores}/base_vectorstore.py
RENAMED
File without changes

src/{vctorstores → vectorstores}/chroma_vectorstore.py
RENAMED
File without changes