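"""Provider-agnostic LLM helpers.

Wraps Groq, OpenRouter, Cerebras, and Google Gemini behind a single
call_llm() entry point that returns structured (instructor) responses,
enforces an in-process rate limit, and tracks token usage per model
and per calling function.
"""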
import os
import time
from typing import Any, Dict, List, Optional

from cerebras.cloud.sdk import Cerebras
from dotenv import load_dotenv
import google.generativeai as genai
from groq import Groq
import instructor
from instructor.exceptions import InstructorRetryException
from limits import parse, storage, strategies
from openai import OpenAI
from pydantic import BaseModel
from tenacity import Retrying, stop_after_attempt, wait_exponential
# In-process moving-window rate limiter: at most 10 LLM calls per minute.
memory_storage = storage.MemoryStorage()
moving_window = strategies.MovingWindowRateLimiter(memory_storage)
rate_limit = parse("10/minute")
# All three tiers currently point at the same model; swap per tier as needed.
MODEL = 'gemini-1.5-flash-latest'
MODEL_FAST = 'gemini-1.5-flash-latest'
MODEL_RAG = 'gemini-1.5-flash-latest'
# Global variables to track LLM usage
_LLM_USAGE = {MODEL: {"input_tokens": 0, "output_tokens": 0},
              MODEL_FAST: {"input_tokens": 0, "output_tokens": 0},
              MODEL_RAG: {"input_tokens": 0, "output_tokens": 0}}
_LLM_USAGE_SPLIT = []
def get_llm_usage():
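    """Print the raw usage counters and return (per-model totals, per-call split, per-function totals)."""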
    print(_LLM_USAGE)
    print(_LLM_USAGE_SPLIT)
    # Calculate total usage per function
    function_totals = {}
    for entry in _LLM_USAGE_SPLIT:
        fn = entry['function']
        if fn not in function_totals:
            function_totals[fn] = {'total_input': 0, 'total_output': 0}
        function_totals[fn]['total_input'] += entry['input_usage']
        function_totals[fn]['total_output'] += entry['output_usage']
    return _LLM_USAGE, _LLM_USAGE_SPLIT, function_totals
load_dotenv()

LLM_TYPE = 'google'
def get_llm_instructor():
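    """Return an instructor-patched client for the provider selected by LLM_TYPE."""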
    if LLM_TYPE == 'groq':
        return instructor.from_groq(Groq(api_key=os.environ["GROQ_API_KEY"]), mode=instructor.Mode.TOOLS)
    elif LLM_TYPE == 'openrouter':
        return instructor.from_openai(OpenAI(api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"), mode=instructor.Mode.MD_JSON)
    elif LLM_TYPE == 'cerebras':
        return instructor.from_cerebras(Cerebras(api_key=os.environ['CEREBRAS_API_KEY']), mode=instructor.Mode.CEREBRAS_JSON)
    elif LLM_TYPE == 'google':
        # configure() returns None, so it cannot be passed as generation_config;
        # it must be called before the model is constructed.
        genai.configure(api_key=os.environ['GEMINI_API_KEY'])
        return instructor.from_gemini(
            client=genai.GenerativeModel(model_name="models/gemini-1.5-flash-latest"),
            mode=instructor.Mode.GEMINI_JSON)
def call_llm(instructions: str, context: dict, response_model: Any, model_type: str = 'slow',
             additional_messages: Optional[List[Dict[str, str]]] = None, logging_fn: str = 'default') -> Any:
    """Standardizes LLM calls: rate-limits, retries failed structured calls, and records token usage."""
    messages = [{"role": "system", "content": instructions}]
    if additional_messages:
        messages.extend(additional_messages)
    # Block until the moving window has room; hit() records the request
    # (test() alone never consumes quota, so the limit would never apply).
    while not moving_window.hit(rate_limit):
        time.sleep(0.1)
    model = MODEL_RAG if model_type == 'rag' else (MODEL if model_type == 'slow' else MODEL_FAST)
    try:
        client = get_llm_instructor()
        if LLM_TYPE == 'google':
            # The Gemini client is already bound to a model, so none is passed here.
            response, completion = client.chat.completions.create_with_completion(
                messages=messages,
                context=context,
                response_model=response_model
            )
        else:
            response, completion = client.chat.completions.create_with_completion(
                model=model,
                messages=messages,
                temperature=0.5,
                context=context,
                max_retries=Retrying(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1.5, min=10, max=60)),
                response_model=response_model
            )
    except InstructorRetryException as e:
        print(e)
        # Wait for rate-limit headroom again before the fallback attempt.
        while not moving_window.hit(rate_limit):
            time.sleep(0.1)

        def retry_callback(retry_state):
            # Log each retry attempt. tenacity's `before` hook ignores the
            # return value, so it cannot be used to raise the temperature here.
            print(f'retrying (attempt {retry_state.attempt_number})....')

        if LLM_TYPE == 'google':
            response, completion = client.chat.completions.create_with_completion(
                messages=messages,
                context=context,
                response_model=response_model,
                max_retries=Retrying(
                    stop=stop_after_attempt(3),
                    wait=wait_exponential(multiplier=1.5, min=10, max=60),
                    before=retry_callback
                )
            )
        else:
            response, completion = client.chat.completions.create_with_completion(
                model=model,
                messages=messages,
                context=context,
                response_model=response_model,
                max_retries=3
            )
    # Update usage statistics (Gemini and OpenAI-style clients name the usage fields differently)
    usage = completion.usage_metadata if LLM_TYPE == 'google' else completion.usage
    input_tokens = usage.prompt_token_count if LLM_TYPE == 'google' else usage.prompt_tokens
    output_tokens = usage.candidates_token_count if LLM_TYPE == 'google' else usage.completion_tokens
    _LLM_USAGE[model]['input_tokens'] += input_tokens
    _LLM_USAGE[model]['output_tokens'] += output_tokens
    _LLM_USAGE_SPLIT.append({
        'function': logging_fn,
        'input_usage': input_tokens,
        'output_usage': output_tokens
    })
    return response
if __name__ == "__main__":
    # Use a pydantic model for the response schema; instructor validates
    # structured output against pydantic models.
    class ResponseModel(BaseModel):
        answer: str

    instructions = "What are the key differences between Glean Search and MS Copilot?"
    context = {}
    response_model = ResponseModel
    print(call_llm(instructions, context, response_model))
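    # Quick sketch of the usage report: get_llm_usage() returns the
    # per-model totals, the per-call split, and per-function aggregates.
    totals, split, per_function = get_llm_usage()
    print(per_function)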