from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor

from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
|
|
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """
    Returns a PromptCompressor configured for the requested compression method.
    """
    if compress_method == "llm_lingua2":
        llm_lingua2 = PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu"
        )
        return llm_lingua2

    elif compress_method == "llm_lingua":
        llm_lingua = PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu"
        )
        return llm_lingua
    raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
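
# Usage sketch (illustrative, kept commented out): fetching a compressor and
# shrinking a long string to roughly a third of its tokens. compress_prompt and
# its "compressed_prompt" result key come from the llmlingua API; the sample
# text is a placeholder.
#
# compressor = get_llm_lingua("llm_lingua2")
# result = compressor.compress_prompt(
#     "Some long technical passage...",
#     rate=0.33,
#     force_tokens=['\n', '?', '.', '!', ','],
# )
# print(result["compressed_prompt"])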
|
|
def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result for each document using LLMLingua.
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        # compress_prompt takes the text to compress as its first positional
        # argument (the original `doc=` keyword is not part of its signature)
        doc_process_history.append(llm_lingua.compress_prompt(
            str(doc_process_history[-1]),
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
        )["compressed_prompt"])

    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
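
# Example config sketch: the "configurable" keys this node reads. The values
# shown mirror the fallback defaults above and are illustrative only.
#
# config = {
#     "configurable": {
#         "compression_method": "llm_lingua2",
#         "compress_rate": 0.33,
#         "force_tokens": ['\n', '?', '.', '!', ','],
#     }
# }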
|
|
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes the last processing result for each document in state["docs_in_processing"].
    """
    prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.

Document:
{document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    model = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
    doc_process_histories = state["docs_in_processing"]
    if model == "gpt-4o":
        llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_summarize = ChatGroq(model=model)
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()

    for doc_process_history in doc_process_histories:
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))

    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
|
|
def custom_process(state: DocProcessorState):
    """
    Custom processing step; params are stored in a dict in state["process_steps"][state["current_process_step"]]:
        processing_model : the LLM which will perform the processing
        context : index (or list of indices) of previous processing results to send as context to the LLM
        prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    processing_params = state["process_steps"][state["current_process_step"]]
    model = processing_params.get("processing_model") or "mixtral-8x7b-32768"
    user_prompt = processing_params["prompt"]
    context = processing_params.get("context") or [0]
    doc_process_histories = state["docs_in_processing"]
    if not isinstance(context, list):
        context = [context]

    processing_chain = get_model(model=model) | StrOutputParser()

    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))

    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
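
# Example step params sketch: a custom step that feeds the original formatted
# doc (index 0) and the latest result (index -1) as context. The key names
# match the lookups above; the prompt text is a placeholder.
#
# custom_step = {
#     "processing_model": "mixtral-8x7b-32768",
#     "context": [0, -1],
#     "prompt": "Extract the key requirements from the information above.",
# }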
|
|
def final(state: DocProcessorState):
    """
    A node to store the final results of processing in the 'valid_docs' field.
    """
    return {"valid_docs": [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
|
|
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Initializes processing: resets the step counter and seeds each document's
    processing history with its formatted content.
    """
    return {"current_process_step": 0, "docs_in_processing": [[format_doc(doc)] for doc in state["valid_docs"]]}
|
|
def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function that routes to the next processing step.
    """
    process_steps = state["process_steps"]
    if state["current_process_step"] < len(process_steps):
        step = process_steps[state["current_process_step"]]
        if isinstance(step, dict):
            step = "custom"
    else:
        step = "final"

    return step
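
# Example process_steps sketch: string entries name built-in nodes, dict
# entries are routed to the "custom" node; once every step has run, the
# router returns "final". The prompt text is a placeholder.
#
# process_steps = [
#     "summarize",
#     "compress",
#     {"prompt": "List the open issues mentioned above.", "context": [-1]},
# ]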
|
|
def build_data_processor_graph(memory):
    """
    Builds the data processor graph.
    """
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",
        next_processor_step,
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"compress": "compress", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"}
    )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
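
# Usage sketch (assumptions flagged): MemorySaver is LangGraph's standard
# in-memory checkpointer; the input and config keys follow DocProcessorState
# and ConfigSchema as used above. The docs list and thread_id are placeholders.
#
# from langgraph.checkpoint.memory import MemorySaver
#
# graph = build_data_processor_graph(MemorySaver())
# result = graph.invoke(
#     {"valid_docs": docs, "process_steps": ["summarize", "compress"]},
#     config={"configurable": {"compress_rate": 0.33, "thread_id": "doc-proc-1"}},
# )
# print(result["valid_docs"])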