File size: 4,526 Bytes
121f84a
f24a24a
121f84a
f24a24a
121f84a
f24a24a
121f84a
f24a24a
 
 
 
 
121f84a
 
 
 
f24a24a
121f84a
 
 
 
 
f24a24a
 
 
 
 
 
 
 
 
 
121f84a
 
 
f24a24a
121f84a
 
 
f24a24a
121f84a
f24a24a
 
 
121f84a
 
 
 
 
 
f24a24a
121f84a
 
 
 
 
 
 
 
 
 
 
 
 
 
f24a24a
121f84a
f24a24a
 
121f84a
 
 
f24a24a
 
121f84a
f24a24a
121f84a
 
f24a24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import json
import logging
import requests
import urllib3

# The gateway calls in this module pass verify=False, so silence urllib3's
# InsecureRequestWarning that would otherwise be emitted for those requests.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)


def check_server_health(cloud_gateway_api: str, header: dict, timeout: float = 10.0) -> bool:
    """
    Probe the gateway's ``model/info`` endpoint to check whether the server is up.

    Args:
        cloud_gateway_api: Base URL of the gateway. "model/info" is appended
            directly, so the value is expected to end with "/".
        header: Headers (e.g. Authorization) sent with the probe.
        timeout: Seconds to wait for the connection and for the response
            before giving up. Without a timeout, an unresponsive gateway would
            block the caller indefinitely. A timeout is reported as "down".

    Returns:
        True if the endpoint answered with a non-error (not 4xx/5xx) status,
        False if the request failed, timed out, or returned an error status.
    """
    try:
        response = requests.get(
            cloud_gateway_api + "model/info",
            headers=header,
            verify=False,  # TLS certificate verification is deliberately disabled
            timeout=timeout,
        )
        # Turns 4xx/5xx responses into requests.HTTPError (a RequestException).
        response.raise_for_status()
    except requests.RequestException as e:
        logging.error("Failed to check server health: %s", e)
        return False
    return True


def _parse_conversation_stream_line(line: str):
    """
    Extract the generated text carried by one decoded line of the conversation stream.

    Args:
        line: A single non-empty line from the stream, already decoded to ``str``.

    Returns:
        None if the line is the ``[DONE]`` end-of-stream marker. Otherwise, the
        string at ``choices[0].delta.content``, or "" when the line is not valid
        JSON or does not carry such a string.
    """
    # Strip at most two leading "data: " prefixes. Presumably the gateway can
    # re-wrap already-prefixed SSE lines -- TODO confirm against the gateway.
    for _ in range(2):
        if line.startswith("data: "):
            line = line[len("data: "):]

    # End-of-stream sentinel.
    if line.strip() == "[DONE]":
        return None

    try:
        chunk_json = json.loads(line)
    except json.JSONDecodeError:
        # Lines that are not JSON carry no generated text; skip them.
        return ""

    # Guard every level so unexpected chunk shapes yield no text instead of
    # raising KeyError/TypeError and aborting the whole stream.
    if not isinstance(chunk_json, dict):
        return ""
    choices = chunk_json.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    delta = choices[0].get("delta")
    if not isinstance(delta, dict):
        return ""
    content = delta.get("content")
    return content if isinstance(content, str) else ""


def request_generation(
    header: dict,
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request a streaming chat generation from the cloud gateway API.

    The request happens in two steps. First, it POSTs the chat payload to
    ``chat/conversation`` to obtain a ``conversationId``. Then it GETs
    ``conversation/stream`` with that ID in the ``X-Conversation-ID`` header
    and reads the response line by line (``stream=True``), so text can be
    yielded token by token as it arrives.

    Args:
        header: Headers for the API (e.g. Authorization). The dict is not
            modified; the conversation ID is added to a copy.
        message: Prompt from the user.
        system_prompt: System prompt sent ahead of the user message.
        cloud_gateway_api: Base URL of the gateway. Endpoint paths are appended
            directly, so the value is expected to end with "/".
        model_name: Model identifier placed in the payload's "model" field.
        max_new_tokens: Maximum number of tokens to generate (sent as "max_tokens").
        temperature: Sampling temperature used to modulate next-token probabilities.
        frequency_penalty: Forwarded as the payload's "frequency_penalty".
        presence_penalty: Forwarded as the payload's "presence_penalty".

    Yields:
        Non-empty text fragments as they are streamed. If a request fails, or
        the gateway's reply is malformed, the error is logged and the single
        message "Server not responding. Please try again later." is yielded
        instead. Fragments already yielded before the failure are not retracted.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }

    try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # Put the conversation ID on a copy of the headers. Mutating the caller's
        # dict would leak a stale X-Conversation-ID into any later request that
        # reuses it.
        stream_headers = dict(header)
        stream_headers["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=stream_headers,
            verify=False,
            stream=True,
        ) as stream_response:
            # Surface 4xx/5xx instead of silently parsing an error body as "no output".
            stream_response.raise_for_status()
            for raw_line in stream_response.iter_lines():
                if not raw_line:
                    continue  # blank separator lines carry no data
                content = _parse_conversation_stream_line(raw_line.decode("utf-8"))
                if content is None:
                    break  # [DONE] sentinel: generation finished
                if content:
                    yield content
    except (requests.RequestException, KeyError, ValueError) as e:
        # KeyError: reply lacked "conversationId". ValueError: reply body was
        # not JSON (older requests versions) or a line was not valid UTF-8.
        logging.error("Failed to generate response: %s", e)
        yield "Server not responding. Please try again later."