# JohnPorkEater's picture
# Update app.py
# 7a360f2 verified
# raw
# history blame
# 3.54 kB
import gradio as gr
import os
import requests
import json
from typing import Optional
# Module-level configuration for the chat-completions calls made below.
# Endpoint that post_request() sends its POST to.
TOGETHER_API_URL = "https://api.together.xyz/v1/chat/completions"
# Value sent as the "model" field of every request payload.
MODEL_NAME = "NousResearch/Nous-Hermes-2-Yi-34B"
# Base HTTP headers; get_streamed_response() copies this dict and adds the
# per-request "Authorization" header to the copy.
HEADERS = {"accept": "application/json", "content-type": "application/json"}
# Message list seeded with the system prompt; its contents form the start of
# the "messages" field sent to the API.
all_message = [{"role": "system", "content": "... system message ..."}]
def validate_input(message: str) -> bool:
    """Return True when *message* is usable as a chat prompt.

    A message is accepted only if it is a string that contains at least one
    non-whitespace character. Anything else (``None``, non-string objects,
    empty or blank text) is rejected so that no pointless API request is made.

    Args:
        message: The raw text submitted by the user.

    Returns:
        True if the message should be forwarded to the API, False otherwise.
    """
    # isinstance() comes first so that .strip() is only called on real strings.
    return isinstance(message, str) and bool(message.strip())
def get_token() -> Optional[str]:
    """Look up the Together API key in the process environment.

    Returns:
        The value of ``TOGETHER_API_KEY``, or None (after printing a notice)
        when the variable is not defined.
    """
    token = os.environ.get('TOGETHER_API_KEY')
    if token is None:
        print("The TOGETHER_API_KEY environment variable is not set.")
    return token
def post_request(
    payload: dict,
    headers: dict,
    timeout: tuple[float, float] = (10.0, 120.0),
) -> Optional[requests.Response]:
    """Send a streaming POST request to the chat-completions endpoint.

    Args:
        payload: JSON-serialisable request body.
        headers: HTTP headers, including the ``Authorization`` header.
        timeout: ``(connect, read)`` timeout in seconds passed to
            ``requests.post``. Without a timeout a stalled server would block
            the calling handler indefinitely.

    Returns:
        The open, streamed response when the server answered with a success
        status, or None when the request failed for any reason (connection
        error, timeout, HTTP error status). The failure is printed.
    """
    response = None
    try:
        response = requests.post(
            TOGETHER_API_URL,
            json=payload,
            headers=headers,
            stream=True,
            timeout=timeout,
        )
        response.raise_for_status()
        return response
    except requests.exceptions.RequestException as e:
        # With stream=True the connection stays checked out until the body is
        # consumed or the response is closed, so release it on HTTP errors.
        if response is not None:
            response.close()
        print(f"An error occurred while making the API request: {e}")
        return None
def process_stream(response: "requests.Response") -> str:
    """Collect the assistant text from a server-sent-events response.

    Each non-empty line of the form ``data: <json>`` is parsed and the string
    found at ``choices[0].delta.content`` is appended to the result. Reading
    stops at the ``data: [DONE]`` sentinel or when the stream ends.

    Chunks that are not valid JSON or do not have the expected structure are
    reported and skipped, so one bad chunk no longer discards the rest of the
    reply. Chunks whose delta carries no text (missing or ``null`` content)
    are skipped silently.

    Args:
        response: A streamed response; any object exposing ``iter_lines()``
            and ``close()`` works.

    Returns:
        The concatenated assistant text (possibly empty).
    """
    pieces = []
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8').strip()
            if not line.startswith("data:"):
                continue
            # Accept both "data: x" and "data:x".
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                break
            try:
                chunk = json.loads(data)
                content = chunk['choices'][0]['delta'].get('content')
            except (json.JSONDecodeError, KeyError, IndexError,
                    TypeError, AttributeError) as e:
                print(f"An error occurred while processing the stream: {e}")
                continue
            # Only real text is appended; None or other types are ignored.
            if isinstance(content, str):
                pieces.append(content)
    finally:
        # Release the underlying connection even if iteration fails.
        response.close()
    return "".join(pieces)
def _history_to_messages(history: list) -> list:
    """Convert Gradio chat history into role/content message dicts."""
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # History given as {"role": ..., "content": ...} entries.
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            # History given as (user_text, assistant_text) pairs.
            user_text, assistant_text = turn
            if isinstance(user_text, str) and user_text:
                messages.append({"role": "user", "content": user_text})
            if isinstance(assistant_text, str) and assistant_text:
                messages.append({"role": "assistant", "content": assistant_text})
    return messages


def get_streamed_response(message: str, history: list) -> str:
    """Answer one chat turn by calling the chat-completions API.

    The conversation sent to the API is rebuilt on every call from the system
    prompt, the ``history`` supplied by the chat interface and the new user
    message. No module-level state is modified, so separate sessions do not
    see each other's messages and a failed call leaves no orphaned turn.

    Args:
        message: The text the user just submitted.
        history: Previous turns of this conversation, either as
            ``(user, assistant)`` pairs or as role/content dicts.

    Returns:
        The assistant's reply, or a short error string when the input is
        invalid, the API key is missing, or the request fails.
    """
    if not validate_input(message):
        return "Invalid input."

    api_key = get_token()
    if not api_key:
        return "Unable to retrieve the API key."

    # Only the system prompt is taken from the module-level list; the rest of
    # the conversation comes from this session's own history.
    # NOTE(review): earlier error replies shown in the chat are part of
    # `history` and are therefore sent along as assistant turns.
    system_messages = [m for m in all_message if m.get("role") == "system"]
    messages = (
        system_messages
        + _history_to_messages(history)
        + [{"role": "user", "content": message}]
    )

    # Copy the shared base headers so the API key never lands in HEADERS.
    headers = HEADERS.copy()
    headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": MODEL_NAME,
        "temperature": 1.1,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": messages,
        "stream_tokens": True,
    }

    response = post_request(payload, headers)
    if response is None:
        return "Failed to get a response from the API."
    return process_stream(response)
# Build the chat UI around get_streamed_response and start serving it.
gr.ChatInterface(
    fn=get_streamed_response,
    title="TherapistGPT",
    description="...",
    retry_btn="Regenerate 🔁",
).launch()