File size: 3,543 Bytes
c368c0e
 
 
 
7a360f2
c368c0e
7a360f2
 
 
 
c368c0e
7a360f2
 
c368c0e
7a360f2
 
 
 
c368c0e
7a360f2
 
 
 
 
 
 
c368c0e
7a360f2
 
 
 
 
 
 
 
 
c368c0e
7a360f2
 
 
 
 
 
 
 
 
 
 
c368c0e
 
7a360f2
 
 
 
c368c0e
7a360f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c368c0e
7a360f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import os
import requests
import json
from typing import Optional

# --- Together API configuration ---------------------------------------------

# Chat-completions endpoint every request is POSTed to.
TOGETHER_API_URL = "https://api.together.xyz/v1/chat/completions"

# Model identifier placed in each request payload.
MODEL_NAME = "NousResearch/Nous-Hermes-2-Yi-34B"

# Base request headers; the Authorization header is added per request.
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
}

# --- Conversation seed ------------------------------------------------------

# Starts out holding only the system prompt.
all_message = [
    {"role": "system", "content": "... system message ..."},
]

def validate_input(message: str, max_length: Optional[int] = None) -> bool:
    """Return True if *message* is acceptable to send to the chat API.

    A message is rejected when it is not a string, when it is empty or
    contains only whitespace (nothing meaningful to send), or, if
    *max_length* is given, when it is longer than *max_length* characters.

    Args:
        message: The raw text the user submitted.
        max_length: Optional upper bound on the message length in characters;
            ``None`` (the default) disables the length check.

    Returns:
        True if the message may be sent, False otherwise.
    """
    if not isinstance(message, str) or not message.strip():
        return False
    return max_length is None or len(message) <= max_length

def get_token() -> Optional[str]:
    """Look up the Together API key.

    Returns:
        The value of the ``TOGETHER_API_KEY`` environment variable, or
        ``None`` (after printing a notice) when the variable is not set.
    """
    api_key = os.environ.get('TOGETHER_API_KEY')
    if api_key is None:
        print("The TOGETHER_API_KEY environment variable is not set.")
    return api_key

def post_request(
    payload: dict,
    headers: dict,
    timeout: "float | tuple[float, float]" = (10, 120),
) -> Optional[requests.Response]:
    """POST *payload* to ``TOGETHER_API_URL`` and return the streaming response.

    Args:
        payload: JSON-serialisable request body.
        headers: HTTP headers to send (e.g. content type and ``Authorization``).
        timeout: Forwarded to ``requests.post``. By default this is a
            ``(connect, read)`` pair. The read timeout bounds the wait between
            bytes from the server, not the total duration, so a slow but
            progressing stream is not cut off. Without a timeout, a stalled
            server would block the call forever.

    Returns:
        The open streaming response on success (the caller must consume and
        close it). Returns ``None`` after printing the error if the request
        could not be made or the server answered with an HTTP error status.
    """
    try:
        response = requests.post(
            TOGETHER_API_URL,
            json=payload,
            headers=headers,
            stream=True,
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while making the API request: {e}")
        return None

    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        # With stream=True the body was never read; close the response so the
        # underlying connection is released instead of leaking.
        response.close()
        print(f"An error occurred while making the API request: {e}")
        return None
    return response

def process_stream(response: "requests.Response") -> str:
    """Collect the assistant's reply from a server-sent-events (SSE) response.

    Each relevant line of the body has the form ``data: <json>``, where the
    JSON carries a text fragment at ``choices[0].delta.content``. The stream
    ends with ``data: [DONE]``. Fragments are concatenated in arrival order.

    Chunks without text are skipped silently, for example a delta that only
    carries a role, or a final chunk whose ``content`` is ``null``. A malformed
    chunk is reported and skipped rather than truncating the rest of the
    reply. If the connection fails mid-stream, the text received so far is
    returned.

    The response is always closed before returning, so the streamed
    connection is released.

    Args:
        response: A streaming response (``stream=True``) exposing
            ``iter_lines()`` and ``close()``.

    Returns:
        The concatenated assistant text (possibly empty).
    """
    fragments = []
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines separate SSE events
            line = raw_line.decode('utf-8').strip()
            if not line.startswith("data:"):
                continue  # ignore SSE comments and non-data fields
            # The SSE spec makes the space after "data:" optional.
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                break
            try:
                chunk_data = json.loads(data)
                content = chunk_data['choices'][0]['delta'].get('content')
            except (json.JSONDecodeError, KeyError, IndexError,
                    TypeError, AttributeError) as e:
                print(f"An error occurred while processing the stream: {e}")
                continue
            if isinstance(content, str):
                fragments.append(content)
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while reading the stream: {e}")
    finally:
        response.close()
    return "".join(fragments)

def _history_to_messages(history: list) -> list:
    """Convert Gradio chat history into chat-completion message dicts.

    NOTE: depending on the Gradio version/configuration, ``ChatInterface``
    passes history either as ``[user_text, assistant_text]`` pairs or as
    ``{"role": ..., "content": ...}`` dicts, so both layouts are accepted.
    Turns whose content is not a non-empty string (``None`` placeholders,
    file attachments) are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        elif isinstance(turn, (list, tuple)):
            pairs = zip(("user", "assistant"), turn)
        else:
            continue
        for role, content in pairs:
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
    return messages


def get_streamed_response(message: str, history: list) -> str:
    """Answer *message* in the context of the current chat session.

    Used as the ``fn`` of the Gradio ``ChatInterface`` at the bottom of this
    file. The API streams its reply, but this function returns the full text
    only once the stream has been consumed.

    Args:
        message: The user's new message.
        history: Prior turns of this chat session, as supplied by Gradio.

    Returns:
        The assistant's reply, or a short error string if the input is
        invalid, the API key is missing, or the request fails.
    """
    if not validate_input(message):
        return "Invalid input."

    api_key = get_token()
    if not api_key:
        return "Unable to retrieve the API key."

    headers = HEADERS.copy()
    headers["Authorization"] = f"Bearer {api_key}"

    # Rebuild the conversation from Gradio's per-session ``history`` on every
    # call instead of accumulating it in the module-level ``all_message`` list.
    # That list is shared by every session served by this process, and it
    # drifts out of sync with the UI when a turn is regenerated or the chat is
    # cleared. Only its system prompt is reused.
    system_messages = [m for m in all_message if m.get("role") == "system"]
    messages = (
        system_messages
        + _history_to_messages(history)
        + [{"role": "user", "content": message}]
    )

    payload = {
        "model": MODEL_NAME,
        "temperature": 1.1,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": messages,
        # NOTE(review): Together's chat-completions reference documents a
        # "stream" flag; confirm "stream_tokens" is still honored.
        "stream_tokens": True,
    }

    response = post_request(payload, headers)
    if response is None:
        return "Failed to get a response from the API."
    return process_stream(response)

# Wrap get_streamed_response in a Gradio chat UI and start the web server.
# NOTE(review): `retry_btn` is a Gradio 4.x ChatInterface argument; later
# major versions dropped it — confirm against the pinned Gradio version.
demo = gr.ChatInterface(
    fn=get_streamed_response,
    title="TherapistGPT",
    description="...",
    retry_btn="Regenerate 🔁",
)
demo.launch()