# KIG / kig_core/processing.py
# heymenn: "Update kig_core/processing.py" (commit 81baa99, verified)
import logging
from typing import List, Dict, Any, Union, Optional
from langchain_core.output_parsers import StrOutputParser
from .config import settings
from .llm_interface import get_llm, invoke_llm
from .prompts import SUMMARIZER_PROMPT
from .graph_operations import format_doc_for_llm # Reuse formatting
# llmlingua is an optional dependency; it is only needed when compression is used.
try:
    from llmlingua import PromptCompressor
except ImportError:
    # Keep the name bound so later references do not raise NameError.
    PromptCompressor = None
    LLMLINGUA_AVAILABLE = False
else:
    LLMLINGUA_AVAILABLE = True

logger = logging.getLogger(__name__)

# Compressor instances, keyed by compression method name.
_compressor_cache = {}
def get_compressor(method: str) -> Optional['PromptCompressor']:
    """Return the LLMLingua compressor for ``method``, building it on first use.

    Successfully constructed compressors are stored in the module-level
    ``_compressor_cache`` under ``method`` and reused on later calls.

    Args:
        method: Compression method name; ``"llm_lingua2"`` and
            ``"llm_lingua"`` are the supported values.

    Returns:
        The ``PromptCompressor`` for ``method``, or ``None`` when llmlingua is
        not installed, ``method`` is unsupported, or construction fails.
    """
    if not LLMLINGUA_AVAILABLE:
        logger.warning("LLMLingua not installed, compression unavailable.")
        return None

    # Fast path: this method has already been initialised.
    if method in _compressor_cache:
        return _compressor_cache[method]

    logger.info(f"Initializing LLMLingua compressor: {method}")

    # method -> (model name, use_llmlingua2); adjust model names as needed.
    model_specs = {
        "llm_lingua2": ("microsoft/llmlingua-2-xlm-roberta-large-meetingbank", True),
        "llm_lingua": ("microsoft/phi-2", False),  # Requires ~8GB RAM
    }
    try:
        spec = model_specs.get(method)
        if spec is None:
            logger.error(f"Unsupported compression method: {method}")
            return None

        checkpoint, is_v2 = spec
        _compressor_cache[method] = PromptCompressor(
            model_name=checkpoint,
            use_llmlingua2=is_v2,
            device_map="cpu",  # Or "cuda" if GPU available
        )
    except Exception as e:
        logger.error(f"Failed to initialize LLMLingua compressor '{method}': {e}", exc_info=True)
        return None

    return _compressor_cache[method]
def summarize_document(doc_content: str) -> str:
    """Produce a summary of one document with the configured summarization LLM.

    The chain is ``SUMMARIZER_PROMPT`` piped into the model named by
    ``settings.summarize_llm_model`` and then into a string output parser; the
    document is passed under the ``"document"`` input key.

    Args:
        doc_content: Text of the document to summarize.

    Returns:
        The summary text. If anything raises, the exception is logged and a
        string describing the error is returned instead of propagating it.
    """
    logger.debug("Summarizing document...")
    try:
        chain = (
            SUMMARIZER_PROMPT
            | get_llm(settings.summarize_llm_model)
            | StrOutputParser()
        )
        payload = {"document": doc_content}
        result = invoke_llm(chain, payload)
        logger.debug("Summarization complete.")
        return result
    except Exception as e:
        logger.error(f"Summarization failed: {e}", exc_info=True)
        # Hand back an error message rather than failing the caller.
        return f"Error during summarization: {e}"
def compress_document(doc_content: str) -> str:
    """Shrink one document with the LLMLingua compressor chosen in settings.

    Args:
        doc_content: Text of the document to compress.

    Returns:
        The compressed text. The input is returned unchanged when no
        compression method is configured or no compressor could be obtained.
        If the compression call itself raises, the exception is logged and a
        string describing the error is returned.
    """
    logger.debug(f"Compressing document using method: {settings.compression_method}...")

    if not settings.compression_method:
        logger.warning("Compression method not configured, skipping.")
        return doc_content

    compressor = get_compressor(settings.compression_method)
    if not compressor:
        logger.warning("Compressor not available, skipping compression.")
        return doc_content

    try:
        # Plain rate-based compression for now; compress_prompt takes further
        # tuning parameters that are not used here.
        rate = settings.compress_rate or 0.5  # default rate when unset
        outcome = compressor.compress_prompt(doc_content, rate=rate)

        # Fall back to the input text if the result carries no compressed prompt.
        shrunk = outcome.get("compressed_prompt", doc_content)

        original_len = len(doc_content.split())
        compressed_len = len(shrunk.split())
        logger.debug(f"Compression complete. Original words: {original_len}, Compressed words: {compressed_len}")
        return shrunk
    except Exception as e:
        logger.error(f"Compression failed: {e}", exc_info=True)
        # Hand back an error message rather than failing the caller.
        return f"Error during compression: {e}"
def _apply_processing_step(content: str, step: Union[str, dict]) -> str:
    """Run a single processing step on ``content`` and return the resulting text."""
    if step == "summarize":
        return summarize_document(content)
    if step == "compress":
        return compress_document(content)
    if isinstance(step, dict):
        # Placeholder: dict-defined custom steps are not implemented yet, so
        # the content passes through untouched.
        logger.warning(f"Custom processing step not implemented: {step}")
        return content
    logger.warning(f"Unknown processing step type: {step}")
    return content


def process_documents(
    docs: List[Dict[str, Any]],
    processing_steps: List[Union[str, dict]]
) -> List[str]:
    """Run every document through the given processing steps, in order.

    Each document is first rendered with ``format_doc_for_llm``; the steps are
    then applied one after another, each receiving the previous step's output.

    Args:
        docs: Documents to process.
        processing_steps: Steps to apply. ``"summarize"`` and ``"compress"``
            are handled; dict steps and any other value only log a warning
            and leave the content unchanged.

    Returns:
        One processed string per input document, in input order. An empty
        list when ``docs`` is empty.
    """
    logger.info(f"Processing {len(docs)} documents with steps: {processing_steps}")
    if not docs:
        return []

    results = []
    for i, doc in enumerate(docs):
        logger.info(f"Processing document {i+1}/{len(docs)}...")
        text = format_doc_for_llm(doc)  # start from the formatted original
        for step in processing_steps:
            text = _apply_processing_step(text, step)
        results.append(text)

    logger.info("Document processing finished.")
    return results