File size: 1,558 Bytes
ea99abb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from dotenv import load_dotenv
from typing import List
from langchain.tools import BaseTool
from langchain.agents import initialize_agent, AgentType
# Load environment variables (e.g. GOOGLE_API_KEY for the google_genai
# provider) from a local .env file at import time; result discarded.
_ = load_dotenv()
class LLM:
    """Thin convenience wrapper around a LangChain chat model.

    Instantiates a chat model via ``init_chat_model`` and exposes small
    helpers for one-shot generation, tool binding, and tweaking sampling
    parameters after construction.
    """

    def __init__(
        self,
        model: str = "gemini-2.0-flash",
        model_provider: str = "google_genai",
        temperature: float = 0.0,
        max_tokens: int = 1000
    ):
        """Create the underlying chat model.

        Args:
            model: Model identifier passed to ``init_chat_model``.
            model_provider: Provider key (e.g. ``"google_genai"``).
            temperature: Sampling temperature (0.0 = deterministic-ish).
            max_tokens: Cap on generated tokens per response.
        """
        self.chat_model = init_chat_model(
            model=model,
            model_provider=model_provider,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def generate(self, prompt: str) -> str:
        """Send *prompt* as a single human message and return the reply text."""
        reply = self.chat_model.invoke([HumanMessage(content=prompt)])
        return reply.content

    def bind_tools(self, tools: List[BaseTool], agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION):
        """
        Bind LangChain tools to this model and return an AgentExecutor.
        """
        # NOTE(review): ``initialize_agent`` is deprecated in recent LangChain
        # releases in favor of the langgraph agent APIs — confirm the pinned
        # langchain version still ships it before upgrading.
        executor = initialize_agent(
            tools,
            self.chat_model,
            agent=agent_type,
            verbose=False,
        )
        return executor

    def set_temperature(self, temperature: float):
        """
        Set the temperature for the chat model.
        """
        # Mutates the already-constructed model in place; presumably takes
        # effect on the next invoke — verify for the chosen provider.
        self.chat_model.temperature = temperature

    def set_max_tokens(self, max_tokens: int):
        """
        Set the maximum number of tokens for the chat model.
        """
        self.chat_model.max_tokens = max_tokens
|