""" | |
Langchain agent | |
""" | |
from typing import Generator, Dict, Optional, Literal, TypedDict, List | |
from dotenv import load_dotenv | |
from langchain_groq import ChatGroq | |
from langchain.memory import ConversationBufferMemory | |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.messages import BaseMessage | |
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableSerializable | |
from langchain_core.output_parsers import StrOutputParser | |
from .prompts import SYSTEM_PROMPT, REFERENCE_SYSTEM_PROMPT | |
load_dotenv() | |

# Groq model identifiers accepted by `MOAgent.from_config` and the layer agent configs.
valid_model_names = Literal[
    'llama3-70b-8192',
    'llama3-8b-8192',
    'gemma-7b-it',
    'gemma2-9b-it',
    'mixtral-8x7b-32768'
]

class ResponseChunk(TypedDict):
    """A single streamed chunk yielded by `MOAgent.chat` when output_format='json'."""
    delta: str
    response_type: Literal['intermediate', 'output']
    metadata: Dict  # TypedDict fields do not support default values; `chat` always supplies this.
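
# In JSON mode, `MOAgent.chat` yields ResponseChunk dicts: intermediate layer outputs carry
# metadata={'layer': <cycle number>}, while the final answer streams as 'output' chunks.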

class MOAgent:
    """Mixture-of-agents orchestrator: layer agents run in parallel and a main agent
    aggregates their responses into the final answer."""

    def __init__(
        self,
        main_agent: RunnableSerializable[Dict, str],
        layer_agent: RunnableSerializable[Dict, Dict],
        reference_system_prompt: Optional[str] = None,
        cycles: Optional[int] = None,
        chat_memory: Optional[ConversationBufferMemory] = None
    ) -> None:
        self.reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT
        self.main_agent = main_agent
        self.layer_agent = layer_agent
        self.cycles = cycles or 1
        self.chat_memory = chat_memory or ConversationBufferMemory(
            memory_key="messages",
            return_messages=True
        )

    @staticmethod
    def concat_response(
        inputs: Dict[str, str],
        reference_system_prompt: Optional[str] = None
    ):
        """Merge the layer agents' outputs into a single helper prompt for the next layer."""
        reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT

        responses = ""
        res_list = []
        for i, out in enumerate(inputs.values()):
            responses += f"{i}. {out}\n"
            res_list.append(out)

        formatted_prompt = reference_system_prompt.format(responses=responses)
        return {
            'formatted_response': formatted_prompt,
            'responses': res_list
        }
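
    # Illustration with hypothetical values: given {'layer_agent_1': 'A', 'layer_agent_2': 'B'},
    # `concat_response` builds the string "0. A\n1. B\n", substitutes it into the {responses}
    # placeholder of the reference system prompt, and returns
    # {'formatted_response': <filled prompt>, 'responses': ['A', 'B']}.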

    @classmethod
    def from_config(
        cls,
        main_model: Optional[valid_model_names] = 'llama3-70b-8192',
        system_prompt: Optional[str] = None,
        cycles: int = 1,
        layer_agent_config: Optional[Dict] = None,
        reference_system_prompt: Optional[str] = None,
        **main_model_kwargs
    ):
        """Build an MOAgent from model names and prompts instead of prebuilt runnables."""
        reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT
        system_prompt = system_prompt or SYSTEM_PROMPT
        layer_agent = MOAgent._configure_layer_agent(layer_agent_config)
        main_agent = MOAgent._create_agent_from_system_prompt(
            system_prompt=system_prompt,
            model_name=main_model,
            **main_model_kwargs
        )
        return cls(
            main_agent=main_agent,
            layer_agent=layer_agent,
            reference_system_prompt=reference_system_prompt,
            cycles=cycles
        )

    @staticmethod
    def _configure_layer_agent(
        layer_agent_config: Optional[Dict] = None
    ) -> RunnableSerializable[Dict, Dict]:
        """Build the parallel layer of agents plus the step that concatenates their outputs."""
        if not layer_agent_config:
            layer_agent_config = {
                'layer_agent_1': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama3-8b-8192'},
                'layer_agent_2': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'gemma-7b-it'},
                'layer_agent_3': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'mixtral-8x7b-32768'}
            }

        parallel_chain_map = dict()
        for key, value in layer_agent_config.items():
            chain = MOAgent._create_agent_from_system_prompt(
                system_prompt=value.pop("system_prompt", SYSTEM_PROMPT),
                model_name=value.pop("model_name", 'llama3-8b-8192'),
                **value  # any remaining keys are forwarded to ChatGroq
            )
            parallel_chain_map[key] = RunnablePassthrough() | chain

        # LCEL coerces the dict of chains into a parallel runnable; its output dict is then
        # collapsed by `concat_response`.
        chain = parallel_chain_map | RunnableLambda(MOAgent.concat_response)
        return chain
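
    # Hypothetical example of a custom `layer_agent_config` for `from_config`: keys are
    # arbitrary agent names; 'system_prompt' and 'model_name' are optional, and any other
    # entries (e.g. temperature) are forwarded to ChatGroq.
    #
    #   layer_agent_config = {
    #       'creative': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'gemma2-9b-it', 'temperature': 0.7},
    #       'precise': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama3-8b-8192', 'temperature': 0.1},
    #   }
    #   agent = MOAgent.from_config(layer_agent_config=layer_agent_config)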

    @staticmethod
    def _create_agent_from_system_prompt(
        system_prompt: str = SYSTEM_PROMPT,
        model_name: str = "llama3-8b-8192",
        **llm_kwargs
    ) -> RunnableSerializable[Dict, str]:
        """Create a single prompt | ChatGroq | parser chain for one agent."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages", optional=True),
            ("human", "{input}")
        ])

        # The system prompt must expose a {helper_response} placeholder so the layer
        # outputs can be injected back into the conversation.
        assert 'helper_response' in prompt.input_variables

        llm = ChatGroq(model=model_name, **llm_kwargs)
        chain = prompt | llm | StrOutputParser()
        return chain

    def chat(
        self,
        input: str,
        messages: Optional[List[BaseMessage]] = None,
        cycles: Optional[int] = None,
        save: bool = True,
        output_format: Literal['string', 'json'] = 'string'
    ) -> Generator[str | ResponseChunk, None, None]:
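        """Run one chat turn through the mixture-of-agents pipeline.

        Each cycle invokes the layer agents and passes their combined output forward as
        `helper_response`; after the last cycle the main agent streams the final answer.
        With output_format='json', intermediate layer responses are yielded as
        ResponseChunk dicts before the final 'output' chunks.
        """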
        cycles = cycles or self.cycles
        llm_inp = {
            'input': input,
            'messages': messages or self.chat_memory.load_memory_variables({})['messages'],
            'helper_response': ""
        }

        for cyc in range(cycles):
            # Run the parallel layer agents on the current input.
            layer_output = self.layer_agent.invoke(llm_inp)
            l_frm_resp = layer_output['formatted_response']
            l_resps = layer_output['responses']

            # Feed the concatenated layer responses into the next cycle (and, after the
            # last cycle, into the main agent) as `helper_response`.
            llm_inp = {
                'input': input,
                'messages': self.chat_memory.load_memory_variables({})['messages'],
                'helper_response': l_frm_resp
            }

            if output_format == 'json':
                for l_out in l_resps:
                    yield ResponseChunk(
                        delta=l_out,
                        response_type='intermediate',
                        metadata={'layer': cyc + 1}
                    )

        # Stream the main agent's final answer, using the last cycle's helper_response.
        stream = self.main_agent.stream(llm_inp)
        response = ""
        for chunk in stream:
            if output_format == 'json':
                yield ResponseChunk(
                    delta=chunk,
                    response_type='output',
                    metadata={}
                )
            else:
                yield chunk
            response += chunk

        if save:
            self.chat_memory.save_context({'input': input}, {'output': response})
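

# Minimal usage sketch (illustrative, not part of the module's public API): assumes a
# GROQ_API_KEY is available in the environment (e.g. via the .env file loaded above) and
# that the listed model names are still served by Groq.
if __name__ == "__main__":
    agent = MOAgent.from_config(
        main_model='llama3-70b-8192',
        cycles=2,
        temperature=0.1  # forwarded to ChatGroq via **main_model_kwargs
    )
    # Default output_format='string' streams plain-text chunks of the final answer.
    for chunk in agent.chat("Explain the difference between a list and a tuple in Python."):
        print(chunk, end="", flush=True)
    print()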