# document_manager.py

import logging
import hashlib
import time
from typing import List, Optional, Any

import chromadb
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PIL import Image
import torch

from config import ResearchConfig

logger = logging.getLogger(__name__)

class QuantumDocumentManager:
    """
    Manages creation of Chroma collections from raw document texts.

    Holds a chromadb client (persistent when ResearchConfig.CHROMA_PATH is
    usable, in-memory otherwise) and an OpenAI embedding model used to embed
    document chunks.
    """
    def __init__(self) -> None:
        # Prefer an on-disk client; fall back to in-memory so the manager
        # stays usable even when the persistence path cannot be opened.
        try:
            self.client = chromadb.PersistentClient(path=ResearchConfig.CHROMA_PATH)
            logger.info("Initialized PersistentClient for Chroma.")
        except Exception:
            logger.exception("Error initializing PersistentClient; falling back to in-memory client.")
            self.client = chromadb.Client()
        self.embeddings = OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=ResearchConfig.EMBEDDING_DIMENSIONS
        )

    def create_collection(self, documents: List[str], collection_name: str) -> Any:
        """
        Split raw texts into chunks, embed them, and store them in a Chroma
        collection.

        Args:
            documents: Raw document texts to chunk and index.
            collection_name: Name of the Chroma collection to create or reuse.

        Returns:
            The populated chromadb collection object.

        Raises:
            Exception: re-raises any error from the splitting step.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=ResearchConfig.CHUNK_SIZE,
            chunk_overlap=ResearchConfig.CHUNK_OVERLAP,
            separators=["\n\n", "\n", "|||"]
        )
        try:
            docs = splitter.create_documents(documents)
            logger.info(f"Created {len(docs)} document chunks for collection '{collection_name}'.")
        except Exception:
            logger.exception("Error during document splitting.")
            raise  # bare raise preserves the original traceback
        # BUG FIX: the original called chromadb.Chroma.from_documents, but the
        # chromadb package has no `Chroma` attribute (that class lives in
        # LangChain), so this always raised AttributeError. Use the native
        # chromadb client API on the client this instance already owns.
        texts = [doc.page_content for doc in docs]
        collection = self.client.get_or_create_collection(name=collection_name)
        if texts:  # Collection.add rejects empty id lists
            collection.add(
                ids=[self._document_id(text) for text in texts],
                documents=texts,
                embeddings=self.embeddings.embed_documents(texts),
            )
        return collection

    def _document_id(self, content: str) -> str:
        """Build an ID: first 16 hex chars of SHA-256(content) + creation epoch.

        NOTE(review): two chunks with identical content created within the
        same second receive the same ID — confirm this collision is acceptable.
        """
        return f"{hashlib.sha256(content.encode()).hexdigest()[:16]}-{int(time.time())}"

class ExtendedQuantumDocumentManager(QuantumDocumentManager):
    """
    Extends QuantumDocumentManager with multi-modal (image) document handling.
    Uses dependency injection for CLIP components.
    """
    def __init__(self, clip_model: Any, clip_processor: Any) -> None:
        """
        Args:
            clip_model: CLIP model exposing get_image_features(**inputs).
            clip_processor: CLIP processor callable with
                (images=..., return_tensors="pt").
        """
        super().__init__()
        self.clip_model = clip_model
        self.clip_processor = clip_processor

    def create_image_collection(self, image_paths: List[str]) -> Optional[Any]:
        """
        Embed each readable image with CLIP and store the vectors in the
        "neuro_images" Chroma collection (documents are the file paths).

        Args:
            image_paths: Paths of image files to embed; unreadable paths are
                skipped with a warning/log entry.

        Returns:
            The populated chromadb collection, or None when no image could
            be processed.
        """
        embeddings: List[List[float]] = []
        valid_images: List[str] = []
        for img_path in image_paths:
            try:
                # Context manager closes the file handle promptly (the
                # original leaked it until GC).
                with Image.open(img_path) as image:
                    inputs = self.clip_processor(images=image, return_tensors="pt")
                with torch.no_grad():
                    emb = self.clip_model.get_image_features(**inputs)
                # Flatten the (1, dim) tensor into one plain vector per image,
                # the shape Chroma expects; .cpu() guards against GPU tensors.
                embeddings.append(emb.squeeze(0).cpu().numpy().tolist())
                valid_images.append(img_path)
            except FileNotFoundError:
                logger.warning(f"Image file not found: {img_path}. Skipping.")
            except Exception as e:
                logger.exception(f"Error processing image {img_path}: {str(e)}")
        if not embeddings:
            logger.error("No valid images found for image collection.")
            return None
        # BUG FIX: the original called chromadb.Chroma.from_embeddings, which
        # does not exist (chromadb has no `Chroma` attribute), so this always
        # raised AttributeError. Store via the inherited client instead, using
        # the same ID scheme as the text collections for consistency.
        collection = self.client.get_or_create_collection(name="neuro_images")
        collection.add(
            ids=[self._document_id(path) for path in valid_images],
            documents=valid_images,
            embeddings=embeddings,
        )
        return collection