File size: 5,216 Bytes
ec6d5f9
3ce7918
ec6d5f9
 
 
5fdee62
ec6d5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fdee62
ec6d5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81baa99
 
 
 
 
 
 
 
ec6d5f9
81baa99
 
 
ec6d5f9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
from typing import List, Dict, Any, Union, Optional
from langchain_core.output_parsers import StrOutputParser

from .config import settings
from .llm_interface import get_llm, invoke_llm
from .prompts import SUMMARIZER_PROMPT
from .graph_operations import format_doc_for_llm # Reuse formatting

# Import llmlingua if compression is used
try:
    from llmlingua import PromptCompressor
    LLMLINGUA_AVAILABLE = True
except ImportError:
    LLMLINGUA_AVAILABLE = False
    PromptCompressor = None # Define as None if not available

logger = logging.getLogger(__name__)

_compressor_cache = {}

def get_compressor(method: str) -> Optional['PromptCompressor']:
    """Initialize and cache an LLMLingua compressor for the given method.

    Args:
        method: Compression backend id; "llm_lingua2" or "llm_lingua".

    Returns:
        A cached ``PromptCompressor`` instance, or ``None`` when llmlingua
        is not installed, the method is unknown, or initialization fails.
    """
    if not LLMLINGUA_AVAILABLE:
        logger.warning("LLMLingua not installed, compression unavailable.")
        return None
    # Failed initializations are not cached, so a later call may retry.
    if method in _compressor_cache:
        return _compressor_cache[method]

    logger.info("Initializing LLMLingua compressor: %s", method)
    # Supported methods mapped to (model_name, use_llmlingua2).
    # Adjust model names and params as needed.
    configs = {
        "llm_lingua2": ("microsoft/llmlingua-2-xlm-roberta-large-meetingbank", True),
        "llm_lingua": ("microsoft/phi-2", False),  # Requires ~8GB RAM
    }
    if method not in configs:
        logger.error("Unsupported compression method: %s", method)
        return None

    model_name, use_llmlingua2 = configs[method]
    try:
        _compressor_cache[method] = PromptCompressor(
            model_name=model_name,
            use_llmlingua2=use_llmlingua2,
            device_map="cpu",  # Or "cuda" if GPU available
        )
    except Exception as e:
        logger.error(f"Failed to initialize LLMLingua compressor '{method}': {e}", exc_info=True)
        return None
    return _compressor_cache[method]


def summarize_document(doc_content: str) -> str:
    """Summarize a single document with the configured summarization LLM.

    Args:
        doc_content: Document text to summarize.

    Returns:
        The summary text, or an error string if the LLM invocation fails.
    """
    logger.debug("Summarizing document...")
    try:
        llm = get_llm(settings.summarize_llm_model)
        chain = SUMMARIZER_PROMPT | llm | StrOutputParser()
        summary_text = invoke_llm(chain, {"document": doc_content})
    except Exception as e:
        logger.error(f"Summarization failed: {e}", exc_info=True)
        return f"Error during summarization: {e}" # Return error message instead of failing
    logger.debug("Summarization complete.")
    return summary_text


def compress_document(doc_content: str) -> str:
    """Compress a single document using LLMLingua.

    Falls back to the original content whenever compression is not
    configured, the compressor is unavailable, or compression raises —
    callers always receive usable document text.

    Args:
        doc_content: Document text to compress.

    Returns:
        The compressed text, or the original text when compression is skipped.
    """
    logger.debug(f"Compressing document using method: {settings.compression_method}...")
    if not settings.compression_method:
        logger.warning("Compression method not configured, skipping.")
        return doc_content

    compressor = get_compressor(settings.compression_method)
    if not compressor:
        logger.warning("Compressor not available, skipping compression.")
        return doc_content

    try:
        # Default to a 0.5 compression rate when not configured.
        result = compressor.compress_prompt(doc_content, rate=settings.compress_rate or 0.5)
        compressed_text = result.get("compressed_prompt", doc_content)
    except Exception as e:
        logger.error(f"Compression failed: {e}", exc_info=True)
        # Fall back to the uncompressed text instead of propagating an
        # error string as document content — consistent with the skip
        # paths above, and avoids losing the document downstream.
        return doc_content

    logger.debug(
        "Compression complete. Original words: %d, Compressed words: %d",
        len(doc_content.split()),
        len(compressed_text.split()),
    )
    return compressed_text


def process_documents(
    docs: List[Dict[str, Any]],
    processing_steps: List[Union[str, dict]]
) -> List[str]:
    """Run each document through an ordered pipeline of processing steps.

    Args:
        docs: Documents as dicts; each is first rendered via
            ``format_doc_for_llm`` before any step is applied.
        processing_steps: Ordered steps. The strings "summarize" and
            "compress" are supported; dict-valued custom steps are
            placeholders and currently only logged.

    Returns:
        One processed string per input document, in input order.
    """
    logger.info(f"Processing {len(docs)} documents with steps: {processing_steps}")
    if not docs:
        return []

    total = len(docs)  # Hoisted out of the loop for the progress log.
    processed_outputs = []
    for index, doc in enumerate(docs, start=1):
        logger.info(f"Processing document {index}/{total}...")
        current_content = format_doc_for_llm(doc)  # Start with formatted original doc

        for step in processing_steps:
            if step == "summarize":
                current_content = summarize_document(current_content)
            elif step == "compress":
                current_content = compress_document(current_content)
            elif isinstance(step, dict):
                # Placeholder for custom processing steps defined by dicts.
                # Add logic here if needed: extract params, call specific LLM/function.
                logger.warning(f"Custom processing step not implemented: {step}")
            else:
                logger.warning(f"Unknown processing step type: {step}")

        # Final processed content for this doc.
        processed_outputs.append(current_content)

    logger.info("Document processing finished.")
    return processed_outputs