adrienbrdne committed
Commit 42c00ab · verified · 1 Parent(s): 19491ad

Update ki_gen/data_processor.py

Files changed (1)
  1. ki_gen/data_processor.py +20 -16
ki_gen/data_processor.py CHANGED

@@ -4,22 +4,21 @@
  from langchain_openai import ChatOpenAI
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import ChatPromptTemplate
- from langchain_groq import ChatGroq
+ # Remove ChatGroq import
+ # from langchain_groq import ChatGroq
+ # Add ChatGoogleGenerativeAI import
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ import os  # Add os import for getenv
+
  from langgraph.graph import StateGraph
  from llmlingua import PromptCompressor

- from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
+ # Import get_model which now handles Gemini
+ from ki_gen.utils import ConfigSchema, DocProcessorState, get_model, format_doc
  from langgraph.checkpoint.sqlite import SqliteSaver


-
-
- # compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question="", target_token=200)
-
- ## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
- ## Before that, you need to pip install optimum auto-gptq
- # llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
-
+ # ... (rest of the imports and llm_lingua functions remain the same)


  # Requires ~2GB of RAM
@@ -61,6 +60,7 @@ def compress(state: DocProcessorState, config: ConfigSchema):

      return {"docs_in_processing": doc_process_histories, "current_process_step" : state["current_process_step"] + 1}

+ # Update default model
  def summarize_docs(state: DocProcessorState, config: ConfigSchema):
      """
      This node summarizes all docs in state["valid_docs"]
@@ -74,12 +74,11 @@ Document:
      sysmsg = ChatPromptTemplate.from_messages([
          ("system", prompt)
      ])
-     model = config["configurable"].get("summarize_model") or "deepseek-r1-distill-llama-70b"
+     # Update default model name
+     model = config["configurable"].get("summarize_model") or "gemini-2.0-flash"
      doc_process_histories = state["docs_in_processing"]
-     if model == "gpt-4o":
-         llm_summarize = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-     else:
-         llm_summarize = ChatGroq(model=model)
+     # Use get_model to handle instantiation
+     llm_summarize = get_model(model)
      summarize_chain = sysmsg | llm_summarize | StrOutputParser()

      for doc_process_history in doc_process_histories:
@@ -87,6 +86,7 @@ Document:

      return {"docs_in_processing": doc_process_histories, "current_process_step": state["current_process_step"] + 1}

+ # Update default model
  def custom_process(state: DocProcessorState):
      """
      Custom processing step, params are stored in a dict in state["process_steps"][state["current_process_step"]]
@@ -96,13 +96,15 @@ def custom_process(state: DocProcessorState):
      """

      processing_params = state["process_steps"][state["current_process_step"]]
-     model = processing_params.get("processing_model") or "deepseek-r1-distill-llama-70b"
+     # Update default model name
+     model = processing_params.get("processing_model") or "gemini-2.0-flash"
      user_prompt = processing_params["prompt"]
      context = processing_params.get("context") or [0]
      doc_process_histories = state["docs_in_processing"]
      if not isinstance(context, list):
          context = [context]

+     # Use get_model
      processing_chain = get_model(model=model) | StrOutputParser()

      for doc_process_history in doc_process_histories:
@@ -113,6 +115,8 @@ def custom_process(state: DocProcessorState):

      return {"docs_in_processing" : doc_process_histories, "current_process_step" : state["current_process_step"] + 1}

+ # ... (rest of the file remains the same)
+
  def final(state: DocProcessorState):
      """
      A node to store the final results of processing in the 'valid_docs' field
 
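With a helper shaped like that, the chain built in summarize_docs composes as an ordinary LCEL pipeline. A standalone usage sketch, reusing the hypothetical get_model above (the system prompt is a placeholder; only the gemini-2.0-flash default comes from the diff):

# Usage sketch mirroring the chain shape in summarize_docs.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

sysmsg = ChatPromptTemplate.from_messages([
    ("system", "Summarize the following document:\n\n{document}")
])
llm_summarize = get_model("gemini-2.0-flash")  # default name from the diff
summarize_chain = sysmsg | llm_summarize | StrOutputParser()

summary = summarize_chain.invoke({"document": "LangGraph coordinates multi-step LLM workflows."})
print(summary)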