#!/usr/bin/env python
# coding: utf-8

from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
# Remove ChatGroq import
# from langchain_groq import ChatGroq
# Add ChatGoogleGenerativeAI import
from langchain_google_genai import ChatGoogleGenerativeAI
import os  # Add os import for getenv
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor

# Import get_model which now handles Gemini
from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
from langgraph.checkpoint.sqlite import SqliteSaver

# ... (rest of the imports and llm_lingua functions remain the same)
def get_llm_lingua(compress_method: str = "llm_lingua2"):
    """Load the requested LLMLingua prompt compressor."""
    # Requires ~2GB of RAM
    if compress_method == "llm_lingua2":
        llm_lingua2 = PromptCompressor(
            model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
            use_llmlingua2=True,
            device_map="cpu",
        )
        return llm_lingua2
    # Requires ~8GB of RAM
    elif compress_method == "llm_lingua":
        llm_lingua = PromptCompressor(
            model_name="microsoft/phi-2",
            device_map="cpu",
        )
        return llm_lingua
    raise ValueError("Incorrect compression method, should be 'llm_lingua' or 'llm_lingua2'")
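
# Minimal usage sketch (the input text below is hypothetical): pick a compressor
# and squeeze a prompt down to roughly a third of its original token count.
#
#   compressor = get_llm_lingua("llm_lingua2")
#   result = compressor.compress_prompt(
#       "Some long technical passage to shrink...",
#       rate=0.33,
#       force_tokens=['\n', '?', '.', '!', ','],
#   )
#   print(result["compressed_prompt"])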
def compress(state: DocProcessorState, config: ConfigSchema):
    """
    This node compresses the last processing result for each doc using llm_lingua
    """
    doc_process_histories = state["docs_in_processing"]
    llm_lingua = get_llm_lingua(config["configurable"].get("compression_method") or "llm_lingua2")
    for doc_process_history in doc_process_histories:
        doc_process_history.append(
            llm_lingua.compress_prompt(
                str(doc_process_history[-1]),  # text to compress, passed positionally
                rate=config["configurable"].get("compress_rate") or 0.33,
                force_tokens=config["configurable"].get("force_tokens") or ['\n', '?', '.', '!', ','],
            )["compressed_prompt"]
        )
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
# Update default model
def summarize_docs(state: DocProcessorState, config: ConfigSchema):
    """
    This node summarizes all docs in state["valid_docs"]
    """
    prompt = """You are a 3GPP standardization expert.
Summarize the provided document in simple technical English for other experts in the field.
Document:
{document}"""
    sysmsg = ChatPromptTemplate.from_messages([
        ("system", prompt)
    ])
    # Update default model name
    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
    doc_process_histories = state["docs_in_processing"]
    # Use get_model to handle instantiation
    llm_summarize = get_model(model)
    summarize_chain = sysmsg | llm_summarize | StrOutputParser()
    for doc_process_history in doc_process_histories:
        doc_process_history.append(summarize_chain.invoke({"document": str(doc_process_history[-1])}))
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
# Update default model
def custom_process(state: DocProcessorState):
    """
    Custom processing step; params are stored in a dict in state["process_steps"][state["current_process_step"]]

    processing_model : the LLM which will perform the processing
    context : indices of the previous processing results to send as context to the LLM
    prompt : the prompt/task which will be appended to the context before sending to the LLM
    """
    processing_params = state["process_steps"][state["current_process_step"]]
    # Update default model name
    model = processing_params.get("processing_model") or "gemini-2.0-flash"
    user_prompt = processing_params["prompt"]
    context = processing_params.get("context") or [0]  # default: the original formatted doc
    doc_process_histories = state["docs_in_processing"]
    if not isinstance(context, list):
        context = [context]
    # Use get_model
    processing_chain = get_model(model=model) | StrOutputParser()
    for doc_process_history in doc_process_histories:
        context_str = ""
        for i, context_element in enumerate(context):
            context_str += f"### TECHNICAL INFORMATION {i+1} \n {doc_process_history[context_element]}\n\n"
        doc_process_history.append(processing_chain.invoke(context_str + user_prompt))
    return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
# ... (rest of the file remains the same)
def final(state: DocProcessorState):
    """
    A node to store the final results of processing in the 'valid_docs' field
    """
    return {"valid_docs": [doc_process_history[-1] for doc_process_history in state["docs_in_processing"]]}
# TODO : remove this node and use conditional entry point instead
def get_process_steps(state: DocProcessorState, config: ConfigSchema):
    """
    Dummy node that initializes the processing state
    """
    # if not process_steps:
    #     process_steps = eval(input("Enter processing steps: "))
    return {"current_process_step": 0, "docs_in_processing": [[format_doc(doc)] for doc in state["valid_docs"]]}
def next_processor_step(state: DocProcessorState):
    """
    Conditional edge function to route to the next processing step
    """
    process_steps = state["process_steps"]
    if state["current_process_step"] < len(process_steps):
        step = process_steps[state["current_process_step"]]
        if isinstance(step, dict):
            step = "custom"
    else:
        step = "final"
    return step
def build_data_processor_graph(memory):
    """
    Builds the data processor graph
    """
    # with SqliteSaver.from_conn_string(":memory:") as memory:
    graph_builder_doc_processor = StateGraph(DocProcessorState)

    graph_builder_doc_processor.add_node("get_process_steps", get_process_steps)
    graph_builder_doc_processor.add_node("summarize", summarize_docs)
    graph_builder_doc_processor.add_node("compress", compress)
    graph_builder_doc_processor.add_node("custom", custom_process)
    graph_builder_doc_processor.add_node("final", final)

    graph_builder_doc_processor.add_edge("__start__", "get_process_steps")
    graph_builder_doc_processor.add_conditional_edges(
        "get_process_steps",
        next_processor_step,
        {"compress": "compress", "final": "final", "summarize": "summarize", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "summarize",
        next_processor_step,
        {"compress": "compress", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "compress",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "custom": "custom"}
    )
    graph_builder_doc_processor.add_conditional_edges(
        "custom",
        next_processor_step,
        {"summarize": "summarize", "final": "final", "compress": "compress", "custom": "custom"}
    )
    graph_builder_doc_processor.add_edge("final", "__end__")

    graph_doc_processor = graph_builder_doc_processor.compile(checkpointer=memory)
    return graph_doc_processor
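
# End-to-end usage sketch (document payload is hypothetical): compile the graph
# with a SQLite checkpointer and run two processing steps over the valid docs.
# The "thread_id" key is required by LangGraph checkpointers.
#
#   with SqliteSaver.from_conn_string(":memory:") as memory:
#       graph = build_data_processor_graph(memory)
#       result = graph.invoke(
#           {"valid_docs": [some_doc], "process_steps": ["summarize", "compress"]},
#           config={"configurable": {"thread_id": "1", "compress_rate": 0.5}},
#       )
#       print(result["valid_docs"])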