#!/usr/bin/env python
# coding: utf-8

from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor

from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc

# Example call:
# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
#
# Or use a quantized model such as TheBloke/Llama-2-7b-Chat-GPTQ, which needs
# <8GB of GPU memory. This requires `pip install optimum auto-gptq` first:
# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})


def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """Instantiate the requested PromptCompressor variant."""
    # llmlingua-2 (XLM-RoBERTa based): requires ~2GB of memory
    if compress_method == "llm_lingua2":
        return PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu",
        )
    # Original llmlingua (phi-2 based): requires ~8GB of memory
    elif compress_method == "llm_lingua":
        return PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu",
        )
    raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")


def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result of each doc using llm_lingua.
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        # compress_prompt takes the text to compress as its `context` argument
        # (it has no `doc` parameter) and returns a dict whose
        # "compressed_prompt" entry holds the compressed text.
        doc_process_history.append(
            llm_lingua.compress_prompt(
                context=str(doc_process_history[-1]),
                rate=config["configurable"].get("compress_rate") or 0.33,
                force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ','],
            )["compressed_prompt"]
        )
    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
    }
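
# Sketch of calling the compress node directly; the state and config shapes
# here are assumptions inferred from how this module reads them, not a tested call:
#
#   state = {"docs_in_processing": [["<raw document text>"]], "current_process_step": 0}
#   config = {"configurable": {"compression_method": "llm_lingua2", "compress_rate": 0.33}}
#   result = compress(state, config)
#   result["docs_in_processing"][0][-1]  # -> the compressed text
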
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes the last processing result of each doc in state["docs_in_processing"].
    """
    prompt = """You are a 3GPP standardization expert. Summarize the provided document in simple technical English for other experts in the field.

Document: {document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    model = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
    doc_process_histories = state["docs_in_processing"]
    if model == "gpt-4o":
        llm_summarize = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_summarize = ChatGroq(model=model)
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()

    for doc_process_history in doc_process_histories:
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))

    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
    }


def custom_process(state: DocProcessorState):
    """
    Custom processing step. Parameters are read from the dict stored in
    state["process_steps"][state["current_process_step"]]:

    processing_model : the LLM which will perform the processing
    context : indices of the previous processing results to send as context to the LLM
    prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    processing_params = state["process_steps"][state["current_process_step"]]
    model = processing_params.get("processing_model") or "mixtral-8x7b-32768"
    user_prompt = processing_params["prompt"]
    context = processing_params.get("context") or [0]
    doc_process_histories = state["docs_in_processing"]
    if not isinstance(context, list):
        context = [context]

    processing_chain = get_model(model=model) | StrOutputParser()

    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))

    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
    }


def final(state: DocProcessorState):
    """
    A node to store the final result of processing for each doc in the 'valid_docs' field.
    """
    return {"valid_docs": [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}


# TODO : remove this node and use a conditional entry point instead
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Initialization node: resets the step counter and seeds each doc's
    processing history with its formatted source document.
    """
    # if not process_steps:
    #     process_steps = eval(input("Enter processing steps: "))
    return {
        "current_process_step": 0,
        "docs_in_processing": [[format_doc(doc)] for doc in state["valid_docs"]],
    }


def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function routing to the next processing step.
    """
    process_steps = state["process_steps"]
    if state["current_process_step"] < len(process_steps):
        step = process_steps[state["current_process_step"]]
        if isinstance(step, dict):
            step = "custom"
    else:
        step = "final"
    return step
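
# Sketch of a process_steps value that drives the routing below. String entries
# name built-in nodes; a dict entry is routed to custom_process, whose keys
# ("prompt", required; "processing_model" and "context", optional) match what
# custom_process reads above. The concrete step list is illustrative only:
#
#   process_steps = [
#       "summarize",                      # summarize_docs
#       "compress",                       # compress
#       {                                 # custom_process
#           "prompt": "List the key requirements as bullet points.",
#           "processing_model": "mixtral-8x7b-32768",
#           "context": [0, -1],           # indices into the doc's processing history
#       },
#   ]
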
def build_data_processor_graph(memory):
    """
    Builds the data processor graph.
    """
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",
        next_processor_step,
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"},
    )
    # Each processing node can route to any step (including itself, so that
    # consecutive identical steps in process_steps remain valid) or to "final".
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"summarize": "summarize", "compress": "compress", "final": "final", "custom": "custom"},
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "compress": "compress", "final": "final", "custom": "custom"},
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"},
    )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
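
# Usage sketch (untested): compile the graph with langgraph's in-memory
# checkpointer and run it over the docs. The "thread_id" entry is required by
# checkpointed graphs; the doc list and step list are illustrative only:
#
#   from langgraph.checkpoint.memory import MemorySaver
#
#   graph = build_data_processor_graph(MemorySaver())
#   final_state = graph.invoke(
#       {"valid_docs": [doc], "process_steps": ["summarize", "compress"]},
#       config={"configurable": {"thread_id": "doc-proc-1", "summarize_model": "mixtral-8x7b-32768"}},
#   )
#   final_state["valid_docs"]  # -> one processed result per input doc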