# Last change: commit 09ee276 by ChenyuRabbitLove
# "feat: add more examples for landing page chatbot"
# File size at this commit: 2.5 kB
import os
import json
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from dotenv import load_dotenv
from typing import List
# Populate the process environment from a .env file (if one is present) so the
# settings read below are available.
load_dotenv()

# Synchronous OpenAI client, authenticated with the OPENAI_API_KEY variable.
client = openai.OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))

app = FastAPI()

# Allow browser code served from http://localhost:3000 to call this API
# cross-origin, with credentials, any HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000"],
)
class Message(BaseModel):
    """A single chat turn supplied by the client."""

    # Text of the turn.
    content: str
    # Speaker role, forwarded verbatim to OpenAI as the message "role".
    # NOTE(review): not restricted to a fixed set, so a client can submit its
    # own "system" turns — confirm that is intended.
    role: str
class ChatRequest(BaseModel):
    """JSON body accepted by the landing-page chat endpoint."""

    # Conversation turns, forwarded to the model in this order after the
    # system prompt (presumably oldest first — confirm with the client).
    messages: List[Message]
# System prompt prepended to every conversation sent to the model.
_SYSTEM_PROMPT = (
    "You are an AI learning assistant for PlayGo AI,\n"
    "an educational platform. Your goal is to help students learn and understand various\n"
    "subjects. Provide clear, concise, and accurate explanations."
)


async def stream_text(messages: List[Message]):
    """Stream a chat completion as Vercel AI data-stream (v1) protocol lines.

    The client's conversation is prefixed with the PlayGo AI system prompt and
    sent to OpenAI with streaming enabled.

    Parts emitted, each terminated by a newline:

    * ``0:<json string>`` for every non-empty text delta;
    * ``d:{"finishReason":"stop","usage":{...}}`` once, when OpenAI reports
      token usage at the end of the stream;
    * ``3:<json string>`` if any exception occurs. The error is reported to
      the client and the generator ends; it is not re-raised.

    Args:
        messages: Conversation turns from the client, forwarded in order.

    Yields:
        str: One protocol line per emitted part.
    """
    import asyncio  # function-scope import keeps this block self-contained

    loop = asyncio.get_running_loop()
    end_of_stream = object()  # sentinel next() returns once the stream is exhausted
    stream = None
    try:
        formatted_messages = [{"role": "system", "content": _SYSTEM_PROMPT}] + [
            {"role": msg.role, "content": msg.content} for msg in messages
        ]
        # `client` is the synchronous OpenAI client. Its blocking network calls
        # run in the default thread pool so they do not stall the event loop
        # (and every other request) while waiting on OpenAI.
        stream = await loop.run_in_executor(
            None,
            lambda: client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=formatted_messages,
                temperature=0.7,
                stream=True,
                # Ask for the trailing usage chunk. Without it, no chunk ever
                # has empty `choices`, so the finish part is never sent.
                stream_options={"include_usage": True},
            ),
        )
        chunks = iter(stream)
        while True:
            chunk = await loop.run_in_executor(None, next, chunks, end_of_stream)
            if chunk is end_of_stream:
                break
            for choice in chunk.choices:
                text = choice.delta.content
                # Role-only and finishing chunks carry no text (content is None).
                # Serialising them would emit an invalid `0:null` part.
                if text:
                    yield '0:{text}\n'.format(text=json.dumps(text))
            # With include_usage, the usage chunk is the only one whose
            # `choices` list is empty. Guard against a missing usage anyway.
            if not chunk.choices and chunk.usage is not None:
                finish = {
                    "finishReason": "stop",
                    "usage": {
                        "promptTokens": chunk.usage.prompt_tokens,
                        "completionTokens": chunk.usage.completion_tokens,
                    },
                }
                yield 'd:{}\n'.format(json.dumps(finish, separators=(",", ":")))
    except Exception as e:
        print(f"Error in stream_text: {str(e)}")
        # Report the failure as a data-stream error part (`3:`). A bare
        # "Error: ..." string is not a valid part in this protocol.
        yield '3:{error}\n'.format(error=json.dumps(str(e)))
    finally:
        # Release the upstream HTTP connection even when the client stops
        # reading before the completion has finished.
        if stream is not None:
            stream.close()
@app.post("/api/landing_page_chat")
async def landing_page_chat(request: ChatRequest):
    """Reply to a landing-page chat request with a streamed model answer.

    The response body is generated incrementally by ``stream_text``. The
    ``x-vercel-ai-data-stream: v1`` header advertises the body as version 1
    of the Vercel AI data-stream format.
    """
    return StreamingResponse(
        stream_text(request.messages),
        headers={'x-vercel-ai-data-stream': 'v1'},
    )
@app.get("/api/hello")
async def root():
    """Return a constant JSON greeting."""
    greeting = "Hello, World!"
    return {"message": greeting}