File size: 2,105 Bytes
bd132d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import os
import uuid
from typing import AsyncGenerator, NoReturn

import google.generativeai as genai
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

# Populate os.environ from a local .env file so GOOGLE_API_KEY can be
# supplied without exporting it in the shell.
load_dotenv()

# Configure the Gemini SDK with the API key taken from the environment.
# NOTE(review): os.getenv returns None when GOOGLE_API_KEY is unset and
# nothing here checks for that; a missing key presumably only surfaces on the
# first model call -- confirm whether a fail-fast check is wanted.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Shared model instance used by get_ai_response and get_ai_image.
model = genai.GenerativeModel("gemini-pro")

# FastAPI application; the websocket and HTTP routes are registered on it
# via the route decorators further down in this module.
app = FastAPI()

# Prompt template wrapping each chat message; str.format fills in {message}.
PROMPT = """
You are a helpful assistant, skilled in explaining complex concepts in simple terms.

{message}
"""  # noqa: E501

# Prompt template for image requests; str.format fills in {description}.
IMAGE_PROMPT = """
Generate an image based on the following description:

{description}
"""  # noqa: E501

async def get_ai_response(message: str) -> AsyncGenerator[str, None]:
    """
    Stream a Gemini answer to ``message`` as growing JSON snapshots.

    The message is wrapped in ``PROMPT`` and sent with ``stream=True``.
    Every text part received is appended to the running answer, and the
    full answer-so-far is yielded as ``{"id": ..., "text": ...}`` JSON.
    All snapshots of one answer share the same random ``id``, so a client
    can replace the previous snapshot rather than append to it.
    """
    prompt = PROMPT.format(message=message)
    stream = await model.generate_content_async(prompt, stream=True)

    reply_id = str(uuid.uuid4())
    accumulated = ""
    async for chunk in stream:
        # Chunks without candidates carry no text; skip them.
        if not chunk.candidates:
            continue
        first_candidate = chunk.candidates[0]
        for part in first_candidate.content.parts:
            accumulated = accumulated + part.text
            snapshot = {"id": reply_id, "text": accumulated}
            yield json.dumps(snapshot)

async def get_ai_image(description: str) -> str:
    """
    Ask the model for an image matching ``description``.

    Args:
        description: Free-text description substituted into ``IMAGE_PROMPT``.

    Returns:
        A JSON string. It is ``{"image_url": <url>}`` when the response
        carries at least one image with a URL, and ``{"error": <message>}``
        otherwise.
    """
    # NOTE(review): google.generativeai.GenerativeModel does not appear to
    # provide ``generate_image_async`` ("gemini-pro" is a text model), so an
    # unconditional call raised AttributeError on every request and turned
    # the HTTP endpoint into a 500. Look the method up defensively and report
    # a structured error instead; confirm against the installed SDK version
    # and switch to a real image-generation API if one is intended.
    generate_image = getattr(model, "generate_image_async", None)
    if generate_image is None:
        return json.dumps(
            {"error": "Image generation is not supported by this model"}
        )

    response = await generate_image(
        IMAGE_PROMPT.format(description=description)
    )

    # Take the first generated image, if any. Tolerate responses without an
    # ``images`` attribute and image objects without a ``url`` rather than
    # raising or returning a null URL to the client.
    images = getattr(response, "images", None)
    if images:
        url = getattr(images[0], "url", None)
        if url:
            return json.dumps({"image_url": url})
    return json.dumps({"error": "No image generated"})

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    Websocket for AI responses.

    Accepts the connection, then loops. Each received text frame is passed
    to ``get_ai_response``, and every JSON snapshot it yields is sent back as
    a text frame. The session ends when the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()
            async for text in get_ai_response(message):
                await websocket.send_text(text)
    except WebSocketDisconnect:
        # The client closed the socket (possibly mid-stream). This is the
        # normal way a session ends, so return quietly instead of letting the
        # exception escape the handler as an unhandled error.
        return

@app.post("/generate-image/")
async def generate_image_endpoint(description: str):
    """
    Generate an image for ``description`` and return the decoded payload.

    ``description`` is a plain ``str`` parameter, which FastAPI reads from the
    query string. ``get_ai_image`` returns a JSON string; it is parsed back
    into a dict here so that FastAPI serialises it as the response body.
    """
    raw_payload = await get_ai_image(description)
    payload = json.loads(raw_payload)
    return payload

if __name__ == "__main__":
    # Serve the app directly when run as a script, on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)