bjhk / ki_gen /data_processor.py
heymenn's picture
Upload 15 files
6aaddef verified
raw
history blame
7.19 kB
#!/usr/bin/env python
# coding: utf-8
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
## Or use a quantized model, like TheBloke/Llama-2-7b-Chat-GPTQ, which only needs <8GB of GPU memory.
## Before that, you need to pip install optimum auto-gptq
# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
# Requires ~2GB of RAM
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """
    Create the LLMLingua prompt compressor selected by `compress_method`.

    "llm_lingua2" loads microsoft/llmlingua-2-xlm-roberta-large-meetingbank
    with the LLMLingua-2 mode enabled; "llm_lingua" loads microsoft/phi-2.
    Both are placed on the CPU. Any other value raises ValueError.
    """
    # Guard clause: reject unknown methods before any model is loaded.
    if compress_method not in ("llm_lingua2", "llm_lingua"):
        raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")

    if compress_method == "llm_lingua2":
        # Requires ~2GB memory
        compressor_kwargs = {
            "model_name": "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            "use_llmlingua2": True,
            "device_map": "cpu",
        }
    else:
        # Requires ~8GB memory
        compressor_kwargs = {
            "model_name": "microsoft/phi-2",
            "device_map": "cpu",
        }
    return PromptCompressor(**compressor_kwargs)
def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result of each doc using llm_lingua.

    Optional keys read from config["configurable"]:
        compression_method : "llm_lingua2" (default) or "llm_lingua"
        compress_rate : forwarded as `rate` to the compressor (default 0.33)
        force_tokens : forwarded as `force_tokens` to the compressor
                       (default: newline, '?', '.', '!', ',')

    The compressed text is appended to each doc's processing history (the
    lists in state["docs_in_processing"] are extended in place) and the step
    counter is advanced by one.
    """
    doc_process_histories = state["docs_in_processing"]
    configurable = config["configurable"]
    llm_lingua = get_llm_lingua(configurable.get("compression_method") or "llm_lingua2")

    # These options do not depend on the document, so resolve them once.
    rate = configurable.get("compress_rate") or 0.33
    force_tokens = configurable.get("force_tokens") or ['\n', '?', '.', '!', ',']

    for doc_process_history in doc_process_histories:
        # The text to compress is compress_prompt's first parameter; it is
        # passed positionally because the method has no `doc` keyword
        # (the previous `doc=` call raised TypeError).
        compression_result = llm_lingua.compress_prompt(
            str(doc_process_history[-1]),
            rate=rate,
            force_tokens=force_tokens,
        )
        doc_process_history.append(compression_result["compressed_prompt"])

    return {
        "docs_in_processing": doc_process_histories,
        "current_process_step": state["current_process_step"] + 1,
    }
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes the latest processing result of every doc in
    state["docs_in_processing"].

    The model name is read from config["configurable"]["summarize_model"]
    (default "mixtral-8x7b-32768"). "gpt-4o" is served through ChatOpenAI,
    every other name through ChatGroq. Each summary is appended to its doc's
    history in place and the step counter is advanced by one.
    """
    summary_template = ChatPromptTemplate.from_messages([
        (
            "system",
            """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}""",
        )
    ])

    model_name = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
    histories = state["docs_in_processing"]

    chat_model = (
        ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
        if model_name == "gpt-4o"
        else ChatGroq(model=model_name)
    )
    chain = summary_template | chat_model | StrOutputParser()

    for history in histories:
        latest_text = str(history[-1])
        history.append(chain.invoke({"document": latest_text}))

    next_step = state["current_process_step"] + 1
    return {"docs_in_processing": histories, "current_process_step": next_step}
def custom_process(state: DocProcessorState):
    """
    Custom processing step. Its parameters are the dict stored at
    state["process_steps"][state["current_process_step"]]:
        processing_model : the LLM which performs the processing
                           (default "mixtral-8x7b-32768")
        context : index, or list of indexes, into each doc's processing
                  history selecting the results sent as context (default [0])
        prompt : the task appended after the context before calling the LLM

    The LLM output is appended to each doc's history in place and the step
    counter is advanced by one.
    """
    params = state["process_steps"][state["current_process_step"]]
    model_name = params.get("processing_model") or "mixtral-8x7b-32768"
    task = params["prompt"]
    context_indexes = params.get("context") or [0]
    histories = state["docs_in_processing"]

    # A single index is accepted as shorthand for a one-element list.
    if not isinstance(context_indexes, list):
        context_indexes = [context_indexes]

    chain = get_model(model=model_name) | StrOutputParser()

    for history in histories:
        sections = [
            f"### TECHNICAL INFORMATION {number} \n {history[index]}\n\n"
            for number, index in enumerate(context_indexes, start=1)
        ]
        history.append(chain.invoke("".join(sections) + task))

    return {
        "docs_in_processing": histories,
        "current_process_step": state["current_process_step"] + 1,
    }
def final(state: DocProcessorState):
    """
    Terminal node: copies the last entry of every processing history in
    state["docs_in_processing"] into the 'valid_docs' field.
    """
    last_results = []
    for history in state["docs_in_processing"]:
        last_results.append(history[-1])
    return {"valid_docs": last_results}
# TODO : remove this node and use conditional entry point instead
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Entry node: resets the step counter to 0 and starts one processing
    history per doc in state["valid_docs"], each seeded with format_doc(doc).

    `config` is accepted but not used.
    """
    initial_histories = []
    for doc in state["valid_docs"]:
        initial_histories.append([format_doc(doc)])
    return {"current_process_step": 0, "docs_in_processing": initial_histories}
def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function choosing the next node.

    Returns "final" once every entry of state["process_steps"] has been
    consumed, "custom" when the upcoming step is a dict of parameters, and
    otherwise the upcoming step itself.
    """
    steps = state["process_steps"]
    position = state["current_process_step"]

    # No step left to run.
    if position >= len(steps):
        return "final"

    upcoming = steps[position]
    return "custom" if isinstance(upcoming, dict) else upcoming
def build_data_processor_graph(memory):
    """
    Builds the data processor graph.

    Layout: __start__ -> get_process_steps -> (summarize | compress | custom)*
    -> final -> __end__. After get_process_steps and after every processing
    node, next_processor_step picks the next node.

    memory: passed as `checkpointer` when compiling the graph.
    Returns the compiled graph.
    """
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    processing_nodes = {
        "summarize": summarize_docs,
        "compress": compress,
        "custom": custom_process,
    }

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    for node_name, node_fn in processing_nodes.items():
        graph_builder_doc_processor.add_node(node_name, node_fn)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")

    # Every routing node gets the same complete map. Previously "summarize"
    # and "compress" had no route back to themselves, so two identical
    # consecutive steps made the router return a key missing from the map.
    routes = {node_name: node_name for node_name in (*processing_nodes, "final")}
    for source in ("get_process_steps", *processing_nodes):
        graph_builder_doc_processor.add_conditional_edges(
            source,
            next_processor_step,
            dict(routes),  # separate copy per edge so the maps are not shared
        )

    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor