import gradio as gr
import os
import requests
import json
from typing import Optional

# Define constants
TOGETHER_API_URL = "https://api.together.xyz/v1/chat/completions"
MODEL_NAME = "NousResearch/Nous-Hermes-2-Yi-34B"
HEADERS = {"accept": "application/json", "content-type": "application/json"}

# Initialize message history
all_message = [{"role": "system", "content": "... system message ..."}]

def validate_input(message: str) -> bool:
    """Validate the user input before sending it to the API."""
    # Add input validation logic here, such as checking length, content, etc.
    return True  # Placeholder for actual validation logic
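
# A minimal sketch (not wired in, purely illustrative) of what validate_input
# could do: reject blank messages and cap their length. The 4000-character
# limit is an assumed value, not anything the API requires.
def _example_validate_input(message: str) -> bool:
    """Illustrative only: reject empty messages and overly long ones."""
    stripped = message.strip()
    return bool(stripped) and len(stripped) <= 4000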

def get_token() -> Optional[str]:
    """Retrieve the API token from the environment variables."""
    try:
        return os.environ['TOGETHER_API_KEY']
    except KeyError:
        print("The TOGETHER_API_KEY environment variable is not set.")
        return None
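
# Usage note: the key is read from the TOGETHER_API_KEY environment variable,
# so set it before launching (e.g. `export TOGETHER_API_KEY=...` locally, or a
# repository secret on Hugging Face Spaces); otherwise get_token() returns None
# and the app reports an error instead of calling the API.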

def post_request(payload: dict, headers: dict) -> Optional[requests.Response]:
    """Send a POST request to the API. Returns None if the request fails."""
    try:
        response = requests.post(TOGETHER_API_URL, json=payload, headers=headers, stream=True)
        response.raise_for_status()
        return response
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while making the API request: {e}")
        return None
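
# Note: stream=True tells `requests` not to buffer the whole response body up
# front, so process_stream() below can consume it incrementally via iter_lines().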

def process_stream(response: requests.Response) -> str:
    """Process the streamed response from the API."""
    assistant_response = ""
    try:
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode('utf-8').strip()
                if decoded_line == "data: [DONE]":
                    return assistant_response
                elif decoded_line.startswith("data: "):
                    decoded_line = decoded_line.removeprefix("data: ")
                    chunk_data = json.loads(decoded_line)
                    # Some chunks may omit 'content' in the delta; treat them as empty.
                    content = chunk_data['choices'][0]['delta'].get('content', '')
                    assistant_response += content
    except (json.JSONDecodeError, KeyError) as e:
        print(f"An error occurred while processing the stream: {e}")
    return assistant_response
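
# For reference, the server-sent-event format this parser expects (inferred
# from the parsing logic above, not from the provider's documentation) looks
# roughly like:
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
# with a literal `data: [DONE]` line marking the end of the stream.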

def get_streamed_response(message: str, history: list) -> str:
    """Main function to interact with the chat API."""
    global all_message

    # Validate input
    if not validate_input(message):
        return "Invalid input."

    # Prepare the message history
    all_message.append({"role": "user", "content": message})

    # Retrieve the API token
    api_key = get_token()
    if not api_key:
        return "Unable to retrieve the API key."

    # Set up the headers with the API key
    headers = HEADERS.copy()
    headers["Authorization"] = f"Bearer {api_key}"

    # Prepare the payload for the API request
    payload = {
        "model": MODEL_NAME,
        "temperature": 1.1,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": all_message,
        "stream_tokens": True,
    }

    # Send the request and process the stream
    response = post_request(payload, headers)
    if response:
        assistant_response = process_stream(response)
        all_message.append({"role": "assistant", "content": assistant_response})
        return assistant_response
    else:
        return "Failed to get a response from the API."

# Launch the Gradio interface
gr.ChatInterface(fn=get_streamed_response, title="TherapistGPT", description="...", retry_btn="Regenerate 🔁").launch()
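
# Usage: running this script directly (e.g. `python app.py`, assuming that
# filename) starts the Gradio server and prints a local URL; on Hugging Face
# Spaces the interface is served automatically.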