Update ki_gen/data_processor.py

ki_gen/data_processor.py  +20 -16  CHANGED
@@ -4,22 +4,21 @@
 from langchain_openai import ChatOpenAI
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_groq import ChatGroq
+# Remove ChatGroq import
+# from langchain_groq import ChatGroq
+# Add ChatGoogleGenerativeAI import
+from langchain_google_genai import ChatGoogleGenerativeAI
+import os # Add os import for getenv
+
 from langgraph.graph import StateGraph
 from llmlingua import PromptCompressor
 
-[removed line not shown in this view]
+# Import get_model which now handles Gemini
+from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
 from langgraph.checkpoint.sqlite import SqliteSaver
 
 
-
-
-# compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
-
-## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
-## Before that, you need to pip install optimum auto-gptq
-# llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
-
+# ... (rest of the imports and llm_lingua functions remain the same)
 
 
 # Requires ~2GB of RAM
@@ -61,6 +60,7 @@ def compress(state: DocProcessorState, config: ConfigSchema):
 
     return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
 
+# Update default model
 def summarize_docs(state: DocProcessorState, config: ConfigSchema):
     """
     This node summarizes all docs in state["valid_docs"]
@@ -74,12 +74,11 @@ Document:
     sysmsg = ChatPromptTemplate.from_messages([
         ("system", prompt)
     ])
-    [removed line not shown in this view]
+    # Update default model name
+    model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
     doc_process_histories = state["docs_in_processing"]
-    [removed line not shown in this view]
-    [removed line not shown in this view]
-    else:
-        llm_summarize = ChatGroq(model=model)
+    # Use get_model to handle instantiation
+    llm_summarize = get_model(model)
     summarize_chain = sysmsg | llm_summarize | StrOutputParser()
 
     for doc_process_history in doc_process_histories:
@@ -87,6 +86,7 @@ Document:
 
     return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}
 
+# Update default model
 def custom_process(state: DocProcessorState):
     """
     Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
@@ -96,13 +96,15 @@ def custom_process(state: DocProcessorState):
     """
 
     processing_params = state["process_steps"][state["current_process_step"]]
-    [removed line not shown in this view]
+    # Update default model name
+    model = processing_params.get("processing_model") or "gemini-2.0-flash"
     user_prompt = processing_params["prompt"]
     context = processing_params.get("context") or [0]
     doc_process_histories = state["docs_in_processing"]
     if not isinstance(context, list):
         context = [context]
 
+    # Use get_model
     processing_chain = get_model(model=model) | StrOutputParser()
 
     for doc_process_history in doc_process_histories:
@@ -113,6 +115,8 @@ def custom_process(state: DocProcessorState):
 
     return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}
 
+# ... (rest of the file remains the same)
+
 def final(state: DocProcessorState):
     """
     A node to store the final results of processing in the 'valid_docs' field
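
Both updated nodes now rely on ki_gen.utils.get_model to pick the chat model class. That helper is not part of this diff, so the snippet below is only a rough sketch of the dispatch the added comments imply (Gemini served through ChatGoogleGenerativeAI, with the API key read from the environment via os.getenv). The GOOGLE_API_KEY variable name and the prefix check are assumptions for illustration, not the repository's actual code.

# Hypothetical sketch of get_model -- the real implementation lives in
# ki_gen/utils.py and is not shown in this diff.
import os

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI


def get_model(model: str = "gemini-2.0-flash"):
    """Return a chat model instance for the given model name."""
    if model.startswith("gemini"):
        # Assumes the key is exported as GOOGLE_API_KEY (hence the new `import os`).
        return ChatGoogleGenerativeAI(model=model, google_api_key=os.getenv("GOOGLE_API_KEY"))
    # Anything else falls back to an OpenAI-compatible chat model.
    return ChatOpenAI(model=model)

With a helper of this shape, the summarize_docs change reduces to llm_summarize = get_model(model) followed by the same chain composition as before (sysmsg | llm_summarize | StrOutputParser()).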