File size: 2,664 Bytes
7b2e5db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys

from langchain_anthropic import ChatAnthropic
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import ChatOpenAI

sys.path.append(os.getcwd())
import KEYS
from research_assistant.app_logging import app_logger


def set_api_key(env_var: str, api_key: str):
    """Publish *api_key* to child clients by exporting it as the *env_var* environment variable."""
    os.environ.update({env_var: api_key})


class Agent:
    """Thin wrapper that builds a LangChain chat model from a model name.

    The provider is inferred by substring match on ``model_name`` ("gpt",
    "claude", "gemini", "fireworks"). The matching API key from ``KEYS``
    is exported to the environment before the client is constructed.
    Providers whose key is not configured in ``KEYS`` are skipped.
    """

    # Context-window sizes (tokens) for models whose limit is below the
    # 128k fallback. Class-level so the table is built once, not on
    # every instantiation.
    _MAX_TOKENS_BY_MODEL = {
        "gpt-3.5": 16000,
        "gpt-4": 8000,
        "gpt-4o-mini": 8000,
        "llama-v3p2-1b-instruct": 128000,
        "llama-v3p2-3b-instruct": 128000,
        "llama-v3p1-8b-instruct": 128000,
        "llama-v3p1-70b-instruct": 128000,
        "llama-v3p1-405b-instruct": 128000,
        "mixtral-8x22b-instruct": 64000,
        "mixtral-8x7b-instruct": 32000,
        "mixtral-8x7b-instruct-hf": 32000,
        "qwen2p5-72b-instruct": 32000,
        "gemma2-9b-it": 8000,
        "llama-v3-8b-instruct": 8000,
        "llama-v3-70b-instruct": 8000,
        "llama-v3-70b-instruct-hf": 8000,
    }

    def __init__(self, model_name: str):
        """Initialize the chat model matching *model_name*.

        Raises:
            ValueError: if no supported (and configured) provider
                matches *model_name*.
        """
        # Provider substring -> (client class, env var to export,
        # attribute name on KEYS). The key is resolved lazily with
        # getattr so unconfigured providers are simply skipped.
        providers = {
            "gpt": (ChatOpenAI, "OPENAI_API_KEY", "OPENAI"),
            "claude": (ChatAnthropic, "ANTHROPIC_API_KEY", "ANTHROPIC"),
            "gemini": (ChatVertexAI, "GOOGLE_API_KEY", "VERTEX_AI"),
            "fireworks": (ChatFireworks, "FIREWORKS_API_KEY", "FIREWORKS_AI"),
        }
        for key, (model_class, env_var, key_attr) in providers.items():
            if key not in model_name:
                continue
            api_key = getattr(KEYS, key_attr, None)
            if api_key is None:
                # Provider recognized but no key configured in KEYS;
                # keep scanning in case another substring matches.
                continue
            set_api_key(env_var, api_key)
            model = model_class(model=model_name, temperature=0.5)  # type: ignore
            max_tokens = self._MAX_TOKENS_BY_MODEL.get(model_name, 128000)
            break
        else:
            raise ValueError(f"Model {model_name} not supported")

        app_logger.info(f"Model {model_name} is initialized successfully")
        self.model = model
        self.max_tokens = max_tokens

    def get_model(self):
        """Return the underlying LangChain chat model instance."""
        return self.model