web-researcher / llm_config.py
import os
from tenacity import wait_exponential, Retrying, stop_after_attempt
from dotenv import load_dotenv
import google.generativeai as genai
from groq import Groq
import instructor
from openai import OpenAI
from cerebras.cloud.sdk import Cerebras
from limits import storage, strategies, parse
from typing import List, TypedDict, Dict, Any
import time
from instructor.exceptions import InstructorRetryException
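
# Shared in-memory moving-window limiter: caps LLM calls at 10 per minute across the process.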
memory_storage = storage.MemoryStorage()
moving_window = strategies.MovingWindowRateLimiter(memory_storage)
rate_limit = parse("10/minute")
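# Model tiers ("slow" default, fast, RAG); all three currently point at the same Gemini model.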
MODEL = 'gemini-1.5-flash-latest'
MODEL_FAST = 'gemini-1.5-flash-latest'
MODEL_RAG = 'gemini-1.5-flash-latest'
# Global counters of token usage, keyed by model name (identical names collapse to a single key).
_LLM_USAGE = {m: {"input_tokens": 0, "output_tokens": 0} for m in {MODEL, MODEL_FAST, MODEL_RAG}}
# Per-call usage records, each tagged with the name of the calling function.
_LLM_USAGE_SPLIT = []
def get_llm_usage():
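    """Return aggregate token usage per model, the per-call usage log, and per-function totals."""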
print(_LLM_USAGE)
print(_LLM_USAGE_SPLIT)
# Calculate total usage per function
function_totals = {}
for entry in _LLM_USAGE_SPLIT:
fn = entry['function']
if fn not in function_totals:
function_totals[fn] = {'total_input': 0, 'total_output': 0}
function_totals[fn]['total_input'] += entry['input_usage']
function_totals[fn]['total_output'] += entry['output_usage']
return _LLM_USAGE, _LLM_USAGE_SPLIT, function_totals
load_dotenv()
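# Provider selector: one of 'groq', 'openrouter', 'cerebras', or 'google'.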
LLM_TYPE = 'google'
def get_llm_instructor():
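    """Build an instructor-patched client for the provider selected by LLM_TYPE."""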
if LLM_TYPE == 'groq':
return instructor.from_groq(Groq(api_key=os.environ["GROQ_API_KEY"]), mode=instructor.Mode.TOOLS)
elif LLM_TYPE == 'openrouter':
return instructor.from_openai(OpenAI(api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"), mode=instructor.Mode.MD_JSON)
elif LLM_TYPE == 'cerebras':
return instructor.from_cerebras(Cerebras(api_key = os.environ['CEREBRAS_API_KEY']), mode = instructor.Mode.CEREBRAS_JSON)
    elif LLM_TYPE == 'google':
        # genai.configure() returns None, so call it up front rather than passing its
        # result as generation_config; the client is then bound to the model name.
        genai.configure(api_key=os.environ['GEMINI_API_KEY'])
        return instructor.from_gemini(client=genai.GenerativeModel(model_name=f"models/{MODEL}"),
                                      mode=instructor.Mode.GEMINI_JSON)
def call_llm(instructions: str, context: dict, response_model: Any, model_type: str = 'slow',
             additional_messages: List[Dict[str, str]] = None, logging_fn: str = 'default') -> Any:
    """Standardized LLM call: rate limited, retried on failure, with token usage logged per model and per caller."""
messages = [{"role": "system", "content": instructions}]
if additional_messages:
messages.extend(additional_messages)
    # hit() records the call against the 10/minute window; test() alone never registers
    # usage, so the limit would not actually be enforced.
    while not moving_window.hit(rate_limit):
        time.sleep(0.1)
model = MODEL_RAG if model_type == 'rag' else (MODEL if model_type == 'slow' else MODEL_FAST)
try:
client = get_llm_instructor()
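        # The Gemini client is bound to a model at construction time, so only the
        # OpenAI-compatible providers pass an explicit `model` argument here.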
if LLM_TYPE == 'google':
response, completion = client.chat.completions.create_with_completion(
messages=messages,
context=context,
response_model=response_model
)
else:
response, completion = client.chat.completions.create_with_completion(
model=model,
messages=messages,
temperature=0.5,
context=context,
max_retries=Retrying(stop = stop_after_attempt(2), wait= wait_exponential(multiplier=1.5, min=10, max=60)),
response_model=response_model
)
    except InstructorRetryException as e:
        print(e)
        while not moving_window.hit(rate_limit):
            time.sleep(0.1)

        def retry_callback(retry_state):
            # Intended to bump the temperature on each retry, but tenacity's `before`
            # hook ignores the return value, so this currently only logs the attempt.
            print('retrying....')
            new_temp = 0.1 + (retry_state.attempt_number * 0.2)
            return max(0.1, min(0.9, new_temp))  # keep between 0.1 and 0.9
if LLM_TYPE == 'google':
response, completion = client.chat.completions.create_with_completion(
messages=messages,
context=context,
response_model=response_model,
max_retries=Retrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1.5, min=10, max=60),
before=retry_callback
)
)
else:
response, completion = client.chat.completions.create_with_completion(
model=model,
messages=messages,
context=context,
response_model=response_model,
max_retries=3
)
# Update usage statistics
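    # Gemini reports usage via `usage_metadata` (prompt_token_count / candidates_token_count);
    # the OpenAI-style providers report `usage` (prompt_tokens / completion_tokens).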
usage = completion.usage_metadata if LLM_TYPE == 'google' else completion.usage
input_tokens = usage.prompt_token_count if LLM_TYPE == 'google' else usage.prompt_tokens
output_tokens = usage.candidates_token_count if LLM_TYPE == 'google' else usage.completion_tokens
_LLM_USAGE[model]['input_tokens'] += input_tokens
_LLM_USAGE[model]['output_tokens'] += output_tokens
_LLM_USAGE_SPLIT.append({
'function': logging_fn,
'input_usage': input_tokens,
'output_usage': output_tokens
})
return response
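
# Quick smoke test: issue a single structured call against the configured provider.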
if __name__ == "__main__":
class ResponseModel(TypedDict):
answer: str
instructions = "What are the key differences between Glean Search and MS Copilot?"
context = {}
response_model = ResponseModel
print(call_llm(instructions, context, response_model))