import os
from tenacity import wait_exponential, Retrying, stop_after_attempt
from dotenv import load_dotenv
import google.generativeai as genai
from groq import Groq
import instructor
from openai import OpenAI
from cerebras.cloud.sdk import Cerebras
from limits import storage, strategies, parse
from typing import List, TypedDict, Dict, Any
import time
from instructor.exceptions import InstructorRetryException

# In-process rate limiting: every call_llm() invocation shares a single moving
# window of 10 requests per minute.
memory_storage = storage.MemoryStorage()
moving_window = strategies.MovingWindowRateLimiter(memory_storage)
rate_limit = parse("10/minute")


MODEL = 'gemini-1.5-flash-latest'
MODEL_FAST = 'gemini-1.5-flash-latest'
MODEL_RAG = 'gemini-1.5-flash-latest'

# Global variable to track LLM usage
_LLM_USAGE = {  MODEL: {"input_tokens": 0, "output_tokens": 0}, 
                MODEL_FAST: {"input_tokens": 0, "output_tokens": 0}, 
                MODEL_RAG: {"input_tokens": 0, "output_tokens": 0}}
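# _LLM_USAGE aggregates tokens per model alias (the aliases above currently all
# point at the same model, so they collapse into a single dict entry), while
# _LLM_USAGE_SPLIT keeps one record per call, tagged with the logging_fn label.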
_LLM_USAGE_SPLIT = []

def get_llm_usage():
    print(_LLM_USAGE)
    print(_LLM_USAGE_SPLIT)
    # Calculate total usage per function
    function_totals = {}
    for entry in _LLM_USAGE_SPLIT:
        fn = entry['function']
        if fn not in function_totals:
            function_totals[fn] = {'total_input': 0, 'total_output': 0}
        function_totals[fn]['total_input'] += entry['input_usage']
        function_totals[fn]['total_output'] += entry['output_usage']
    return _LLM_USAGE, _LLM_USAGE_SPLIT, function_totals



load_dotenv()

LLM_TYPE = 'google'

def get_llm_instructor():
    if LLM_TYPE == 'groq':
        return instructor.from_groq(Groq(api_key=os.environ["GROQ_API_KEY"]), mode=instructor.Mode.TOOLS)

    elif LLM_TYPE == 'openrouter':
        return instructor.from_openai(OpenAI(api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1"), mode=instructor.Mode.MD_JSON)
    
    elif LLM_TYPE == 'cerebras':
        return instructor.from_cerebras(Cerebras(api_key=os.environ['CEREBRAS_API_KEY']), mode=instructor.Mode.CEREBRAS_JSON)

    elif LLM_TYPE == 'google':
        # genai.configure() returns None, so it cannot serve as a generation_config;
        # configure the API key first, then build the model.
        genai.configure(api_key=os.environ['GEMINI_API_KEY'])
        return instructor.from_gemini(
            client=genai.GenerativeModel(model_name="models/gemini-1.5-flash-latest"),
            mode=instructor.Mode.GEMINI_JSON,
        )


def call_llm(instructions: str, context: dict, response_model: Any, model_type:str = 'slow', additional_messages: List[Dict[str, str]] = None, logging_fn = 'default') -> Any:
        """Standardizes LLM calls with optional retries."""
        messages = [{"role": "system", "content": instructions}]
        if additional_messages:
            messages.extend(additional_messages)

        # Wait for capacity in the shared moving window, then record the hit;
        # test() alone never consumes the window, so hit() is required.
        while not moving_window.test(rate_limit):
            time.sleep(0.1)
        moving_window.hit(rate_limit)

        model = MODEL_RAG if model_type == 'rag' else (MODEL if model_type == 'slow' else MODEL_FAST)

        try:
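            # instructor.from_gemini binds the model to the client at construction,
            # so the google branch below passes no model (or temperature) argument.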
            client = get_llm_instructor()
            if LLM_TYPE == 'google':
                response, completion = client.chat.completions.create_with_completion(  
                            messages=messages,
                            context=context,
                            response_model=response_model
                        )
            else:
                response, completion = client.chat.completions.create_with_completion(  
                            model=model,
                            messages=messages,
                            temperature=0.5,
                            context=context,
                            max_retries=Retrying(stop = stop_after_attempt(2), wait= wait_exponential(multiplier=1.5, min=10, max=60)),
                            response_model=response_model
                        )
            
        except InstructorRetryException as e:
            print(e)
            while not moving_window.test(rate_limit):
                time.sleep(0.1)
            moving_window.hit(rate_limit)

            def retry_callback(retry_state):
                # Log each retry attempt. tenacity ignores the return value of a
                # `before` hook, so this callback cannot alter the temperature.
                print(f'retrying (attempt {retry_state.attempt_number})....')
            if LLM_TYPE == 'google':
                response, completion = client.chat.completions.create_with_completion(
                                messages=messages,
                                context=context,
                                response_model=response_model,
                                max_retries=Retrying(
                                    stop=stop_after_attempt(3),
                                    wait=wait_exponential(multiplier=1.5, min=10, max=60),
                                    before=retry_callback
                                )
                            )
            else:
                response, completion = client.chat.completions.create_with_completion(
                        model=model,
                        messages=messages,
                        context=context,
                        response_model=response_model,
                        max_retries=3
                    )

        # Update usage statistics
        usage = completion.usage_metadata if LLM_TYPE == 'google' else completion.usage
        input_tokens = usage.prompt_token_count if LLM_TYPE == 'google' else usage.prompt_tokens
        output_tokens = usage.candidates_token_count if LLM_TYPE == 'google' else usage.completion_tokens

        _LLM_USAGE[model]['input_tokens'] += input_tokens
        _LLM_USAGE[model]['output_tokens'] += output_tokens
        _LLM_USAGE_SPLIT.append({
            'function': logging_fn,
            'input_usage': input_tokens,
            'output_usage': output_tokens
        })

        
        return response


if __name__ == "__main__":
    class ResponseModel(TypedDict):
        answer: str

    instructions = "What are the key differences between Glean Search and MS Copilot?"
    context = {}
    response_model = ResponseModel
    print(call_llm(instructions, context, response_model))
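
    # A hedged sketch of a richer call (illustrative values only): a follow-up
    # user message, the 'fast' model alias, and a label for _LLM_USAGE_SPLIT.
    followup = [{"role": "user", "content": "Answer in one short paragraph."}]
    print(call_llm(instructions, context, response_model,
                   model_type='fast', additional_messages=followup,
                   logging_fn='demo'))
    print(get_llm_usage())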