File size: 4,261 Bytes
9a31c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da4d7e6
 
9a31c8f
22d49db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a31c8f
 
 
 
 
22d49db
9a31c8f
 
 
 
 
 
 
 
 
0993480
9a31c8f
22d49db
9a31c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
__author__ = "Yash Kumar Lal, Github@ykl7"

import os
import openai
from openai import OpenAI
import anthropic
import time
import random

# Seed the stdlib RNG once at import time so any downstream sampling
# that relies on `random` is reproducible across runs.
random.seed(1234)

class LLMReasoner():
    """Thin wrapper around the OpenAI and Anthropic chat APIs.

    ``options`` must contain ``model_family`` ("OpenAI" or "Anthropic"),
    ``model_name`` and ``max_tokens``; the supported families also need
    ``API_KEY``. Sampling knobs (``temperature``, ``top_p``,
    ``frequency_penalty``, ``presence_penalty``) are optional and default
    to the OpenAI API defaults.
    """

    def __init__(self, options):
        """Build the API client for the requested family and store the
        generation hyperparameters."""
        if options["model_family"] == "OpenAI":
            self.client = OpenAI(api_key=options["API_KEY"])
        elif options["model_family"] == "Anthropic":
            # The anthropic SDK reads its key from the environment.
            os.environ["ANTHROPIC_API_KEY"] = options["API_KEY"]
            self.client = anthropic.Anthropic()

        self.model_family = options["model_family"]
        self.model_name = options["model_name"]
        self.max_tokens = options["max_tokens"]
        # dict.get with a default is the idiomatic form of the original
        # "x if key not in options else options[key]" chains.
        self.temp = options.get("temperature", 0.0)
        self.top_p = options.get("top_p", 1.0)
        self.frequency_penalty = options.get("frequency_penalty", 0.0)
        self.presence_penalty = options.get("presence_penalty", 0.0)

    def make_openai_chat_completions_api_call(self, message):
        """Send ``message`` as a single user turn to the OpenAI chat
        completions API, retrying after transient failures.

        Returns ``(response_text, raw_response)``.
        Raises ``ValueError`` for a model name this wrapper does not
        know how to call; exits the process on rate-limit / not-found
        errors (original behavior, preserved).
        """
        prompt = [{"role": "user", "content": message}]
        try:
            if "gpt-4o" in self.model_name:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=prompt,
                    temperature=self.temp,
                    max_completion_tokens=self.max_tokens,
                    top_p=self.top_p,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty
                )
            elif "o3-mini" in self.model_name:
                # Reasoning models reject the sampling parameters, so only
                # the reasoning effort is passed.
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=prompt,
                    reasoning_effort="medium"
                )
            else:
                # BUG FIX: the original fell through with ``response``
                # unbound and crashed with UnboundLocalError; fail loudly
                # with a meaningful error instead.
                raise ValueError(f"Unsupported OpenAI model: {self.model_name}")
            return self.parse_chat_completions_api_response(response)
        except openai.APIConnectionError as e:
            print("The server could not be reached")
            print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            time.sleep(60)
            # BUG FIX: retry with the original ``message``; the original
            # code passed ``prompt`` (already a messages list), which got
            # wrapped inside a second user turn on every retry.
            return self.make_openai_chat_completions_api_call(message)
        except openai.RateLimitError:
            print("Rate limit error hit")
            exit()
        except openai.NotFoundError:
            print("Model not found")
            exit()
        except openai.APIStatusError as e:
            print("Another non-200-range status code was received")
            print(e.status_code)
            print(e)
            time.sleep(60)
            # BUG FIX: same retry-argument fix as above.
            return self.make_openai_chat_completions_api_call(message)

    def parse_chat_completions_api_response(self, response):
        """Extract the assistant text of the first choice.

        Returns ``(text, response)`` so callers keep the raw object.
        """
        main_response = response.choices[0].message
        return main_response.content, response

    def call_claude(self, claude_prompt=""):
        """Send ``claude_prompt`` as a single user text turn to the
        Anthropic messages API, retrying once per failure after a pause.

        Returns ``(text, raw_message)``, or ``("Error", raw_message)``
        when the first content block is not text.
        """
        try:
            message = self.client.messages.create(
                model=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temp,
                system="",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": claude_prompt
                            }
                        ]
                    }
                ]
            )
        except Exception as e:
            # BUG FIX: the original left a debugging ``breakpoint()`` here
            # and called ``call_claude(self, ...)`` — a NameError, since no
            # such free function exists — and discarded the retry result,
            # leaving ``message`` unbound. Retry through self and return.
            print(e)
            time.sleep(30)
            return self.call_claude(claude_prompt)
        if message.content[0].type == "text":
            return message.content[0].text, message
        else:
            return "Error", message

    def run_inference(self, prompt=[]):
        """Dispatch ``prompt`` to the configured model family and return
        only the response text (the raw response is discarded)."""
        if self.model_family == "OpenAI":
            response_text, _ = self.make_openai_chat_completions_api_call(prompt)
        elif self.model_family == "Anthropic":
            response_text, _ = self.call_claude(prompt)
        else:
            # BUG FIX: the original crashed with UnboundLocalError for an
            # unrecognized family; report the configuration error instead.
            raise ValueError(f"Unsupported model family: {self.model_family}")
        return response_text