import os
import time
from typing import Any, Dict, List, Optional, TypedDict

import google.generativeai as genai
import instructor
from cerebras.cloud.sdk import Cerebras
from dotenv import load_dotenv
from groq import Groq
from instructor.exceptions import InstructorRetryException
from limits import parse, storage, strategies
from openai import OpenAI
from tenacity import Retrying, stop_after_attempt, wait_exponential

# In-memory moving-window rate limiter shared by all calls in this process.
memory_storage = storage.MemoryStorage()
moving_window = strategies.MovingWindowRateLimiter(memory_storage)
rate_limit = parse("10/minute")

MODEL = 'gemini-1.5-flash-latest'
MODEL_FAST = 'gemini-1.5-flash-latest'
MODEL_RAG = 'gemini-1.5-flash-latest'

# Global variables to track LLM usage.
_LLM_USAGE = {
    MODEL: {"input_tokens": 0, "output_tokens": 0},
    MODEL_FAST: {"input_tokens": 0, "output_tokens": 0},
    MODEL_RAG: {"input_tokens": 0, "output_tokens": 0},
}
_LLM_USAGE_SPLIT = []


def get_llm_usage():
    print(_LLM_USAGE)
    print(_LLM_USAGE_SPLIT)
    # Aggregate total usage per calling function.
    function_totals = {}
    for entry in _LLM_USAGE_SPLIT:
        fn = entry['function']
        if fn not in function_totals:
            function_totals[fn] = {'total_input': 0, 'total_output': 0}
        function_totals[fn]['total_input'] += entry['input_usage']
        function_totals[fn]['total_output'] += entry['output_usage']
    return _LLM_USAGE, _LLM_USAGE_SPLIT, function_totals


load_dotenv()

LLM_TYPE = 'google'


def get_llm_instructor():
    """Return an instructor-patched client for the configured provider."""
    if LLM_TYPE == 'groq':
        return instructor.from_groq(
            Groq(api_key=os.environ["GROQ_API_KEY"]),
            mode=instructor.Mode.TOOLS,
        )
    elif LLM_TYPE == 'openrouter':
        return instructor.from_openai(
            OpenAI(api_key=os.getenv("OPENROUTER_API_KEY"),
                   base_url="https://openrouter.ai/api/v1"),
            mode=instructor.Mode.MD_JSON,
        )
    elif LLM_TYPE == 'cerebras':
        return instructor.from_cerebras(
            Cerebras(api_key=os.environ['CEREBRAS_API_KEY']),
            mode=instructor.Mode.CEREBRAS_JSON,
        )
    elif LLM_TYPE == 'google':
        # genai.configure() returns None, so it must be called on its own
        # before constructing the model, not passed as generation_config.
        genai.configure(api_key=os.environ['GEMINI_API_KEY'])
        return instructor.from_gemini(
            client=genai.GenerativeModel(model_name="models/gemini-1.5-flash-latest"),
            mode=instructor.Mode.GEMINI_JSON,
        )


def call_llm(instructions: str, context: dict, response_model: Any,
             model_type: str = 'slow',
             additional_messages: Optional[List[Dict[str, str]]] = None,
             logging_fn: str = 'default') -> Any:
    """Standardizes LLM calls with rate limiting, retries, and usage tracking."""
    messages = [{"role": "system", "content": instructions}]
    if additional_messages:
        messages.extend(additional_messages)

    # Block until the moving window has room; hit() records the request,
    # whereas test() only peeks and would never consume any quota.
    while not moving_window.hit(rate_limit):
        time.sleep(0.1)

    model = MODEL_RAG if model_type == 'rag' else (MODEL if model_type == 'slow' else MODEL_FAST)
    client = get_llm_instructor()
    try:
        if LLM_TYPE == 'google':
            # The Gemini client is bound to a model at construction time,
            # so no model argument is passed here.
            response, completion = client.chat.completions.create_with_completion(
                messages=messages,
                context=context,
                response_model=response_model,
            )
        else:
            response, completion = client.chat.completions.create_with_completion(
                model=model,
                messages=messages,
                temperature=0.5,
                context=context,
                max_retries=Retrying(
                    stop=stop_after_attempt(2),
                    wait=wait_exponential(multiplier=1.5, min=10, max=60),
                ),
                response_model=response_model,
            )
    except InstructorRetryException as e:
        print(e)
        # Wait for rate-limit headroom again before the fallback attempt.
        while not moving_window.hit(rate_limit):
            time.sleep(0.1)

        def retry_callback(retry_state):
            # Intended to raise the temperature on each retry, clamped to
            # [0.1, 0.9]. Note that tenacity's `before` hook ignores the
            # return value, so in practice this only logs the attempt.
            print('retrying....')
            new_temp = 0.1 + (retry_state.attempt_number * 0.2)
            return max(0.1, min(0.9, new_temp))

        if LLM_TYPE == 'google':
            response, completion = client.chat.completions.create_with_completion(
                messages=messages,
                context=context,
                response_model=response_model,
                max_retries=Retrying(
                    stop=stop_after_attempt(3),
                    wait=wait_exponential(multiplier=1.5, min=10, max=60),
                    before=retry_callback,
                ),
            )
        else:
            response, completion = client.chat.completions.create_with_completion(
                model=model,
                messages=messages,
                context=context,
                response_model=response_model,
                max_retries=3,
            )

    # Update usage statistics; Gemini and OpenAI-compatible completions
    # expose token counts under different attribute names.
    usage = completion.usage_metadata if LLM_TYPE == 'google' else completion.usage
    input_tokens = usage.prompt_token_count if LLM_TYPE == 'google' else usage.prompt_tokens
    output_tokens = usage.candidates_token_count if LLM_TYPE == 'google' else usage.completion_tokens
    _LLM_USAGE[model]['input_tokens'] += input_tokens
    _LLM_USAGE[model]['output_tokens'] += output_tokens
    _LLM_USAGE_SPLIT.append({
        'function': logging_fn,
        'input_usage': input_tokens,
        'output_usage': output_tokens,
    })
    return response


if __name__ == "__main__":
    class ResponseModel(TypedDict):
        answer: str

    instructions = "What are the key differences between Glean Search and MS Copilot?"
    context = {}
    response_model = ResponseModel
    print(call_llm(instructions, context, response_model))
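    # A minimal usage sketch (variable names here are illustrative): after one
    # or more call_llm() invocations, get_llm_usage() returns the accumulated
    # token counts per model, the raw per-call log, and totals grouped by the
    # `logging_fn` label passed to call_llm().
    usage_by_model, usage_per_call, usage_by_function = get_llm_usage()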