Spaces:
Running
Running
import os
import json
from typing import List

import openai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Load environment variables from a local .env file (expects OPENAI_API_KEY).
load_dotenv()

# Module-level OpenAI client; the key comes from the environment. If the
# variable is unset this is None and the client errors on first use.
client = openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

app = FastAPI()

# Allow the local dev frontend (http://localhost:3000) to call this API
# from the browser, including with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Message(BaseModel):
    """A single chat message exchanged with the model."""

    content: str  # the message text
    role: str     # chat role, e.g. "user" or "assistant"
class ChatRequest(BaseModel):
    """Request body for the chat endpoint: the conversation so far."""

    messages: List[Message]
async def stream_text(messages: List[Message]):
    """Stream a chat completion in the Vercel AI data-stream format.

    Yields text chunks as ``0:<json-string>\\n`` lines, then a final
    ``d:{...}\\n`` record carrying the finish reason and token usage.
    On failure, yields a plain ``Error: ...`` line instead of raising.
    """
    try:
        # Prepend the fixed system prompt to the caller's conversation.
        formatted_messages = [
            {"role": "system", "content": """You are an AI learning assistant for PlayGo AI,
an educational platform. Your goal is to help students learn and understand various
subjects. Provide clear, concise, and accurate explanations."""},
        ] + [{"role": msg.role, "content": msg.content} for msg in messages]

        stream = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=formatted_messages,
            temperature=0.7,
            stream=True,
            # Ask the API to append a final usage chunk (with empty
            # `choices`); without this the usage branch below never fires.
            stream_options={"include_usage": True},
        )

        for chunk in stream:
            for choice in chunk.choices:
                if choice.finish_reason == "stop":
                    continue
                # The first delta of a stream may carry only the role with
                # content=None -- skip it instead of emitting "0:null".
                if choice.delta.content is None:
                    continue
                yield '0:{text}\n'.format(text=json.dumps(choice.delta.content))

            # The usage-only chunk arrives last and has no choices.
            if not chunk.choices and chunk.usage is not None:
                usage = chunk.usage
                yield 'd:{{"finishReason":"{reason}","usage":{{"promptTokens":{prompt},"completionTokens":{completion}}}}}\n'.format(
                    reason="stop",
                    prompt=usage.prompt_tokens,
                    completion=usage.completion_tokens,
                )
    except Exception as e:
        # Best-effort error reporting: log and surface in the stream as
        # text (str, like every other yield -- the original yielded bytes
        # here, mixing types within one response).
        print(f"Error in stream_text: {str(e)}")
        yield f"Error: {str(e)}"
async def landing_page_chat(request: ChatRequest):
    """Stream a chat completion back to the landing-page chat widget.

    NOTE(review): no route decorator is visible in this chunk -- presumably
    this handler is registered elsewhere (or the decorator was lost in
    formatting); confirm the route wiring.
    """
    response = StreamingResponse(
        stream_text(request.messages),
    )
    # Marks the response as a Vercel AI SDK data stream (protocol v1) so
    # the frontend SDK parses the 0:/d: line format.
    response.headers['x-vercel-ai-data-stream'] = 'v1'
    return response
async def root():
    """Trivial hello endpoint (useful as a liveness check).

    NOTE(review): no route decorator is visible in this chunk -- confirm
    the route wiring.
    """
    return {"message": "Hello, World!"}