#!/usr/bin/env python
# coding: utf-8
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor

# get_model handles provider-specific instantiation of the configured chat model
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc

def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """Loads and returns a PromptCompressor for the requested method."""
    # LLMLingua-2: requires ~2GB of RAM
    if compress_method == "llm_lingua2":
        llm_lingua2 = PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu"
        )
        return llm_lingua2
    # Original LLMLingua (phi-2): requires ~8GB of RAM
    elif compress_method == "llm_lingua":
        llm_lingua = PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu"
        )
        return llm_lingua
    raise ValueError("Unknown compression method, should be 'llm_lingua' or 'llm_lingua2'")
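
# Hedged usage sketch (not from the original file): the compressor loads its
# model eagerly, so it is worth caching the returned instance instead of
# calling get_llm_lingua() once per document.
#
#   compressor = get_llm_lingua("llm_lingua2")
#   result = compressor.compress_prompt("some long prompt ...", rate=0.33)
#   print(result["compressed_prompt"])
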
def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result for each doc using llm_lingua
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        doc_process_history.append(llm_lingua.compress_prompt(
            str(doc_process_history[-1]),  # compress_prompt takes the text as its first positional argument
            rate=config["configurable"].get("compress_rate") or 0.33,
            force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ',']
        )["compressed_prompt"])
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node appends a summary of the last processing result for each doc in state["docs_in_processing"]
    """
    prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
    doc_process_histories = state["docs_in_processing"]
    llm_summarize = get_model(model)
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()
    for doc_process_history in doc_process_histories:
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
def custom_process(state: DocProcessorState):
    """
    Custom processing step, whose params are stored in a dict at state["process_steps"][state["current_process_step"]]:
        processing_model : the LLM which will perform the processing
        context : indices of the previous processing results to send as context to the LLM
        prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    processing_params = state["process_steps"][state["current_process_step"]]
    model = processing_params.get("processing_model") or "gemini-2.0-flash"
    user_prompt = processing_params["prompt"]
    context = processing_params.get("context") or [0]  # default: use the raw formatted doc
    doc_process_histories = state["docs_in_processing"]
    if not isinstance(context, list):
        context = [context]
    processing_chain = get_model(model=model) | StrOutputParser()
    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
def final(state: DocProcessorState):
    """
    A node to store the final results of processing in the 'valid_docs' field
    """
    return {"valid_docs": [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
# TODO : remove this node and use conditional entry point instead
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Initialization node: resets the step counter and wraps each valid doc in its own processing history
    """
    return {"current_process_step": 0, "docs_in_processing": [[format_doc(doc)] for doc in state["valid_docs"]]}
def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function routing to the next processing step, or to 'final' once all steps are done
    """
    process_steps = state["process_steps"]
    if state["current_process_step"] < len(process_steps):
        step = process_steps[state["current_process_step"]]
        if isinstance(step, dict):
            step = "custom"  # custom steps are given as param dicts, not names
    else:
        step = "final"
    return step
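
# Hedged routing example: process_steps may mix built-in step names and custom
# dicts. With the list below the graph visits summarize, then compress, then
# custom, and falls through to final once the counter passes the end.
#
#   process_steps = ["summarize", "compress", {"prompt": "Extract all acronyms."}]
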
def build_data_processor_graph(memory):
    """
    Builds the data processor graph; the caller supplies the checkpointer,
    e.g. with SqliteSaver.from_conn_string(":memory:") as memory: ...
    """
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",
        next_processor_step,
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"compress": "compress", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"}
    )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
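
# Hedged end-to-end sketch, following the SqliteSaver hint in the docstring
# above. The state and config keys follow the nodes defined in this module;
# the docs list is a placeholder for whatever shape format_doc() expects.
#
#   from langgraph.checkpoint.sqlite import SqliteSaver
#   with SqliteSaver.from_conn_string(":memory:") as memory:
#       graph = build_data_processor_graph(memory)
#       final_state = graph.invoke(
#           {"valid_docs": [...], "process_steps": ["summarize", "compress"]},
#           config={"configurable": {"thread_id": "1"}},
#       )
#       print(final_state["valid_docs"])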