# File size: 3,885 Bytes
# 3702f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import uuid
from typing import Iterator, Union, List, Dict
from dotenv import load_dotenv; load_dotenv()
import os
import requests

AVAILABLE_MODELS = [
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "o1-mini",
    "claude-3-sonnet-20240229",
    "gemini-1.5-pro",
    "gemini-1.5-flash",
    "o1-preview",
    "gpt-4o",
]


def _choice_content(data: object, key: str) -> str:
    """Return ``data["choices"][0][key]["content"]`` if it is a string, else ``""``.

    ``key`` is ``"delta"`` for streamed chunks and ``"message"`` for a
    complete (non-streamed) reply. Any missing or wrongly-typed level
    yields ``""`` instead of raising.
    """
    if not isinstance(data, dict):
        return ""
    choices = data.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    part = choices[0].get(key)
    content = part.get("content") if isinstance(part, dict) else None
    return content if isinstance(content, str) else ""


def _iter_sse_content(lines: Iterator[bytes]) -> Iterator[str]:
    """Yield the non-empty ``choices[0].delta.content`` of each SSE ``data:`` line.

    Iteration stops at the ``[DONE]`` sentinel. Blank lines and lines that are
    not ``data:`` lines are skipped. Payloads that are not valid JSON are
    printed and skipped.
    """
    for raw_line in lines:
        if not raw_line:
            continue
        line = raw_line.decode("utf-8").strip()
        # The space after "data:" is optional in SSE, so accept both forms.
        if not line.startswith("data:"):
            continue
        data_str = line[len("data:"):].strip()
        if data_str == "[DONE]":
            break
        try:
            data_json = json.loads(data_str)
        except json.JSONDecodeError:
            print(f"Received non-JSON data: {data_str}")
            continue
        content = _choice_content(data_json, "delta")
        if content:
            yield content


def _stream_content(response: "requests.Response") -> Iterator[str]:
    """Yield content chunks from a streamed response, then release the connection."""
    try:
        yield from _iter_sse_content(response.iter_lines())
    finally:
        # Runs on exhaustion, on [DONE], or when the consumer stops early,
        # so the underlying connection is always released.
        response.close()


def _read_full_content(response: "requests.Response") -> str:
    """Return the complete reply text of a non-streamed response and close it."""
    try:
        raw = response.content
    finally:
        response.close()
    lines = raw.splitlines()
    # Parse SSE framing if the body uses it; otherwise treat it as one JSON document.
    if any(line.strip().startswith(b"data:") for line in lines):
        return "".join(_iter_sse_content(iter(lines)))
    if not raw.strip():
        return ""
    # NOTE(review): assumes an OpenAI-compatible body of the form
    # {"choices": [{"message": {"content": ...}}]} -- confirm against the backend.
    try:
        data = json.loads(raw)
    except ValueError:  # covers both JSONDecodeError and UnicodeDecodeError
        print(f"Received non-JSON data: {raw.decode('utf-8', errors='replace')}")
        return ""
    return _choice_content(data, "message")


def API_Inference(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    stream: bool = False,
    max_tokens: int = 4000,
    temperature: float = 0.7,
    top_p: float = 0.95,
    timeout: Union[float, None] = 300.0,
) -> Union[str, Iterator[str], None]:
    """Send a chat-completion request to the endpoint named by ``AMIGO_BASE_URL``.

    Args:
        messages: Conversation as a list of ``{"role": ..., "content": ...}`` dicts.
        model: Model identifier; must be one of ``AVAILABLE_MODELS``.
        stream: If True, return an iterator that yields text chunks as they
            arrive. Otherwise, return the whole reply as one string.
        max_tokens: Forwarded as ``max_tokens`` in the request payload.
        temperature: Forwarded as ``temperature`` in the request payload.
        top_p: Forwarded as ``top_p`` in the request payload.
        timeout: Passed to ``requests.post``. It limits the time to connect
            and the time between received bytes. ``None`` waits indefinitely.

    Returns:
        The reply text when ``stream`` is False, an iterator of text chunks
        when ``stream`` is True, or ``None`` if ``AMIGO_BASE_URL`` is unset or
        the request fails. Failures are printed, not raised.

    Raises:
        ValueError: If ``model`` is not in ``AVAILABLE_MODELS``.
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(
            f"Model {model} not available. Available models: {', '.join(AVAILABLE_MODELS)}"
        )

    if model == "claude-3-sonnet-20240229":
        # For this model, caller-supplied system messages are dropped and a
        # placeholder "." system message is put first.
        # NOTE(review): presumably a backend workaround -- confirm before
        # relying on system prompts with this model.
        messages = [{"role": "system", "content": "."}] + [
            msg for msg in messages if msg["role"] != "system"
        ]

    api_endpoint = os.environ.get("AMIGO_BASE_URL")
    if not api_endpoint:
        print("An error occurred while making the request: AMIGO_BASE_URL is not set")
        return None

    headers = {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Authorization": "Bearer ",
        "Content-Type": "application/json",
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0"
        ),
        # A fresh random device id is generated for every call.
        "X-Device-UUID": str(uuid.uuid4()),
    }

    payload = {
        "messages": messages,
        "model": model,
        "max_tokens": max_tokens,
        "stream": stream,
        "presence_penalty": 0,
        "temperature": temperature,
        "top_p": top_p,
    }

    response = None
    try:
        response = requests.post(
            api_endpoint, headers=headers, json=payload, stream=stream, timeout=timeout
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Release the connection if we got a response (e.g. an HTTP error status).
        if response is not None:
            response.close()
        print("An error occurred while making the request:", e)
        return None

    if stream:
        return _stream_content(response)
    return _read_full_content(response)

if __name__ == "__main__":
    # Example usage: send one short multi-turn conversation to two models.
    conversation = [
        {"role": "system", "content": "You are a helpful and friendly AI assistant."},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris"},
        {"role": "user", "content": "Who are you. Are you GPT-4o or gpt-3.5?"}
    ]

    # Non-streaming: the whole reply comes back as one string (None on failure).
    response = API_Inference(conversation, stream=False, model="claude-3-sonnet-20240229")
    print(response)

    print("--" * 50)

    # Streaming: an iterator of text chunks, or None if the request failed.
    chunks = API_Inference(conversation, stream=True, model="gpt-4o")
    if chunks is not None:
        for chunk in chunks:
            print(chunk, end="", flush=True)
        # Finish the streamed output with a newline.
        print()