File size: 1,902 Bytes
640b1c8
 
 
e87abff
640b1c8
 
 
 
e87abff
640b1c8
 
 
 
 
 
 
e87abff
640b1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e87abff
640b1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# src/llms/openai_llm.py
import openai
from typing import Optional, List
from openai import OpenAI  # Import the new client

from .base_llm import BaseLLM

class OpenAILanguageModel(BaseLLM):
    """BaseLLM implementation that talks to the OpenAI chat completions API."""

    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
        """
        Initialize OpenAI Language Model

        Args:
            api_key (str): OpenAI API key
            model (str): Name of the OpenAI model to use
        """
        self.client = OpenAI(api_key=api_key)  # Use the new client
        self.model = model

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = 150,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response using OpenAI API

        The prompt is sent as a single ``user`` message and the text of the
        first returned choice is handed back with surrounding whitespace
        removed.

        Args:
            prompt (str): Input prompt
            max_tokens (Optional[int]): Maximum tokens to generate
            temperature (float): Sampling temperature
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``client.chat.completions.create``

        Returns:
            str: Generated response, or an empty string when the first
                choice carries no text content
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        content = response.choices[0].message.content
        # The API types ``content`` as nullable; calling ``.strip()`` on None
        # would raise AttributeError, so fall back to an empty string to
        # honour the declared ``str`` return type.
        if content is None:
            return ""
        return content.strip()

    def tokenize(self, text: str) -> List[str]:
        """
        Split text into whitespace-delimited tokens

        This is a plain ``str.split()`` approximation; it does not use the
        model's own tokenizer, so the pieces will generally differ from the
        tokens the OpenAI API actually counts.

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: List of whitespace-separated tokens
        """
        return text.split()

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the text

        The count is the number of items returned by ``tokenize`` (i.e.
        whitespace-separated words), so it is only an approximation of the
        model's real token usage.

        Args:
            text (str): Input text to count tokens

        Returns:
            int: Number of tokens
        """
        return len(self.tokenize(text))