# -*- coding: utf-8 -*-
__author__ = "Yash Kumar Lal, Github@ykl7"

import os
import random
import time

import anthropic
import openai
from openai import OpenAI

# Seed the module-level RNG so any sampling downstream is reproducible.
random.seed(1234)
class LLMReasoner():
    """Thin wrapper around the OpenAI and Anthropic chat APIs.

    Dispatches on ``options["model_family"]`` ("OpenAI" or "Anthropic"),
    retries transient API failures with a sleep, and returns the model's
    text answer together with the raw API response object.
    """

    def __init__(self, options):
        """Configure the client and sampling parameters.

        Args:
            options: dict with keys "model_family", "model_name", "API_KEY",
                "max_tokens", and optionally "temperature", "top_p",
                "frequency_penalty", "presence_penalty".
        """
        if options["model_family"] == "OpenAI":
            self.client = OpenAI(api_key=options["API_KEY"])
        elif options["model_family"] == "Anthropic":
            # The Anthropic SDK reads the key from the environment.
            os.environ["ANTHROPIC_API_KEY"] = options["API_KEY"]
            self.client = anthropic.Anthropic()
        self.model_family = options["model_family"]
        self.model_name = options["model_name"]
        self.max_tokens = options["max_tokens"]
        # dict.get with defaults replaces the "x not in options" chains.
        self.temp = options.get("temperature", 0.0)
        self.top_p = options.get("top_p", 1.0)
        self.frequency_penalty = options.get("frequency_penalty", 0.0)
        self.presence_penalty = options.get("presence_penalty", 0.0)

    def make_openai_chat_completions_api_call(self, message):
        """Send `message` (a raw user-prompt string) to the OpenAI chat API.

        Retries after 60s on connection / non-200 status errors; exits the
        process on rate-limit or unknown-model errors (original behavior).

        Returns:
            (response_text, raw_response) tuple.

        Raises:
            ValueError: if the configured model name is not supported here.
        """
        prompt = [{"role": "user", "content": message}]
        try:
            if "gpt-4o" in self.model_name:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=prompt,
                    temperature=self.temp,
                    max_completion_tokens=self.max_tokens,
                    top_p=self.top_p,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty
                )
            elif "o3-mini" in self.model_name:
                # Reasoning models ignore sampling knobs; only effort is set.
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=prompt,
                    reasoning_effort="medium"
                )
            else:
                # BUG FIX: previously fell through with `response` unbound
                # (UnboundLocalError); fail with a clear error instead.
                raise ValueError(f"Unsupported OpenAI model: {self.model_name}")
            return self.parse_chat_completions_api_response(response)
        except openai.APIConnectionError as e:
            print("The server could not be reached")
            print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            time.sleep(60)
            # BUG FIX: retry with the original string `message`, not `prompt`
            # (the already-wrapped message list), which double-wrapped the
            # payload on every retry.
            return self.make_openai_chat_completions_api_call(message)
        except openai.RateLimitError as e:
            print("Rate limit error hit")
            exit()
        except openai.NotFoundError as e:
            print("Model not found")
            exit()
        except openai.APIStatusError as e:
            print("Another non-200-range status code was received")
            print(e.status_code)
            print(e)
            time.sleep(60)
            # BUG FIX: same double-wrapping retry bug as above.
            return self.make_openai_chat_completions_api_call(message)

    def parse_chat_completions_api_response(self, response):
        """Extract the first choice's message text from a chat completion.

        Returns:
            (message_content, raw_response) tuple.
        """
        choices = response.choices
        main_response = choices[0].message
        main_response_message, main_response_role = main_response.content, main_response.role
        return main_response_message, response

    def call_claude(self, claude_prompt=""):
        """Send `claude_prompt` to the Anthropic messages API.

        Retries after 30s on any exception. Returns (text, raw_message) on a
        text response, ("Error", raw_message) otherwise.
        """
        try:
            message = self.client.messages.create(
                model=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temp,
                system="",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": claude_prompt
                            }
                        ]
                    }
                ]
            )
        except Exception as e:
            # BUG FIX: removed leftover breakpoint() which halted production
            # runs; the retry was also `call_claude(self, ...)` — a NameError
            # — and its result was never returned, leaving `message` unbound
            # below. Recurse through self and return the result.
            print(e)
            time.sleep(30)
            return self.call_claude(claude_prompt)
        if message.content[0].type == "text":
            return message.content[0].text, message
        else:
            return "Error", message

    def run_inference(self, prompt=""):
        """Route `prompt` to the configured backend and return its text reply.

        Raises:
            ValueError: if the model family is neither "OpenAI" nor
                "Anthropic" (previously an UnboundLocalError).
        """
        if self.model_family == "OpenAI":
            response_text, response = self.make_openai_chat_completions_api_call(prompt)
        elif self.model_family == "Anthropic":
            response_text, response = self.call_claude(prompt)
        else:
            raise ValueError(f"Unknown model family: {self.model_family}")
        return response_text