File size: 7,186 Bytes
6aaddef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
#!/usr/bin/env python
# coding: utf-8
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
## Or use the quantized model, like TheBloke/Llama-2-7b-Chat-GPTQ, which only needs <8GB of GPU memory.
## Before that, you need to pip install optimum auto-gptq
# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
# Requires ~2GB of RAM
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """
    Build the PromptCompressor matching the requested compression method.

    compress_method : either "llm_lingua2" (default) or "llm_lingua"

    Returns a PromptCompressor configured with device_map="cpu".
    Raises ValueError for any other method name.
    """
    if compress_method == "llm_lingua2":
        # Requires ~2GB memory
        compressor_settings = {
            "model_name": "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            "use_llmlingua2": True,
            "device_map": "cpu",
        }
    elif compress_method == "llm_lingua":
        # Requires ~8GB memory
        compressor_settings = {
            "model_name": "microsoft/phi-2",
            "device_map": "cpu",
        }
    else:
        raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")

    return PromptCompressor(**compressor_settings)
def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result of each doc using llm_lingua.

    Settings are read from config["configurable"]:
        compression_method : passed to get_llm_lingua (defaults to "llm_lingua2")
        compress_rate : passed as `rate` to compress_prompt (defaults to 0.33)
        force_tokens : passed as `force_tokens` to compress_prompt
                       (defaults to ['\\n', '?', '.', '!', ','])

    The "compressed_prompt" entry of each compression result is appended to the
    corresponding history in state["docs_in_processing"].

    Returns the updated histories and the incremented "current_process_step".
    """
    configurable = config["configurable"]
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(configurable.get("compression_method") or "llm_lingua2")

    # These settings do not depend on the document, so read them once.
    rate = configurable.get("compress_rate") or 0.33
    force_tokens = configurable.get("force_tokens") or ['\n', '?', '.', '!', ',']

    for doc_process_history in doc_process_histories:
        # The text to compress is passed positionally: compress_prompt has no
        # `doc` keyword, so the previous `doc=...` call could not succeed.
        compression_result = llm_lingua.compress_prompt(
            str(doc_process_history[-1]),
            rate=rate,
            force_tokens=force_tokens,
        )
        doc_process_history.append(compression_result["compressed_prompt"])

    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node appends an LLM-written summary of the last processing result of each doc.

    The model name comes from config["configurable"]["summarize_model"] and
    falls back to "mixtral-8x7b-32768". "gpt-4o" is served through ChatOpenAI,
    any other model name through ChatGroq.

    Returns the updated histories and the incremented "current_process_step".
    """
    system_template = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}"""
    summary_prompt = ChatPromptTemplate.from_messages([("system", system_template)])

    model_name = config["configurable"].get("summarize_model") or "mixtral-8x7b-32768"
    histories = state["docs_in_processing"]

    summarizer = (
        ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
        if model_name == "gpt-4o"
        else ChatGroq(model=model_name)
    )
    chain = summary_prompt | summarizer | StrOutputParser()

    for history in histories:
        latest_result = str(history[-1])
        history.append(chain.invoke({"document": latest_result}))

    next_step = state["current_process_step"] + 1
    return {"docs_in_processing": histories, "current_process_step": next_step}
def custom_process(state: DocProcessorState):
    """
    Custom processing step, params are read from the dict stored in
    state["process_steps"][state["current_process_step"]] :
        processing_model : the LLM which will perform the processing
                           (defaults to "mixtral-8x7b-32768")
        context : index (or list of indexes) of the previous processing results
                  to send as context to the LLM (defaults to [0])
        prompt : the prompt/task which will be appended to the context before
                 sending to the LLM (required)

    The LLM answer is appended to each doc's processing history.
    """
    step_params = state["process_steps"][state["current_process_step"]]
    model_name = step_params.get("processing_model") or "mixtral-8x7b-32768"
    task_prompt = step_params["prompt"]
    context_indexes = step_params.get("context") or [0]
    histories = state["docs_in_processing"]

    # A single index is accepted as shorthand for a one-element list.
    if not isinstance(context_indexes, list):
        context_indexes = [context_indexes]

    chain = get_model(model=model_name) | StrOutputParser()

    for history in histories:
        context_block = "".join(
            f"### TECHNICAL INFORMATION {position} \n {history[index]}\n\n"
            for position, index in enumerate(context_indexes, start=1)
        )
        history.append(chain.invoke(context_block + task_prompt))

    next_step = state["current_process_step"] + 1
    return {"docs_in_processing": histories, "current_process_step": next_step}
def final(state: DocProcessorState):
    """
    Last node: stores the latest entry of every doc's processing history
    in the 'valid_docs' field.
    """
    latest_results = []
    for history in state["docs_in_processing"]:
        latest_results.append(history[-1])
    return {"valid_docs": latest_results}
# TODO : remove this node and use conditional entry point instead
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Entry node: resets the step counter to 0 and starts one processing history
    per doc in state["valid_docs"], each holding format_doc(doc) as its first entry.

    config is accepted but not used.
    """
    initial_histories = []
    for doc in state["valid_docs"]:
        initial_histories.append([format_doc(doc)])
    return {"current_process_step": 0, "docs_in_processing": initial_histories}
def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function selecting the next processing step.

    Returns "final" once every entry of state["process_steps"] has been
    consumed, "custom" when the current entry is a dict, and the entry
    itself otherwise.
    """
    steps = state["process_steps"]
    step_index = state["current_process_step"]

    if step_index >= len(steps):
        return "final"

    current_step = steps[step_index]
    return "custom" if isinstance(current_step, dict) else current_step
def build_data_processor_graph(memory):
    """
    Builds the data processor graph

    The graph starts at "get_process_steps", then next_processor_step routes
    between the "summarize", "compress" and "custom" nodes until it returns
    "final", which leads to the end of the graph.

    memory : passed as `checkpointer` when compiling the graph

    Returns the compiled graph.
    """
    processing_nodes = {
        "summarize": summarize_docs,
        "compress": compress,
        "custom": custom_process,
    }
    # Every routing node can reach every processing node (itself included) or
    # "final". Previously "summarize" and "compress" could not route to
    # themselves, so two identical consecutive steps had no matching route.
    routes = {name: name for name in (*processing_nodes, "final")}

    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    for node_name, node_function in processing_nodes.items():
        graph_builder_doc_processor.add_node(node_name, node_function)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
    for source_node in ("get_process_steps", *processing_nodes):
        graph_builder_doc_processor.add_conditional_edges(
            source_node,
            next_processor_step,
            dict(routes),  # each call gets its own copy of the route map
        )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor