File size: 2,611 Bytes
fccc18d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import json
import tomli
from string import Template

from fastapi import FastAPI, Request, HTTPException
from fastapi import APIRouter

from google.genai import types
from google import genai

from . import prev_summaries

router = APIRouter()

# NOTE(review): key is read once at import time; if GOOGLE_API_KEY is unset this
# is None and every request will fail at the API call — consider failing fast.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
# NOTE(review): building AsyncClient from a raw ApiClient is an unusual path for
# the google-genai SDK — the documented async entry point is
# genai.Client(api_key=...).aio. Confirm this works with the pinned SDK version.
client = genai.client.AsyncClient(genai.client.ApiClient(api_key=GOOGLE_API_KEY))

@router.post("/gemini_summary")
async def gemini_summary(request: Request):
    """Summarize the latest conversation turns with Gemini.

    Expects a JSON body with ``conversation`` (a list of message dicts) and
    optional ``temperature``, ``max_tokens`` and ``model`` overrides, plus an
    ``X-Session-ID`` header identifying the rolling summary to update.

    Returns:
        ``{"summary": <text>}``; the new summary is also stored in
        ``prev_summaries`` under the session id for the next request.

    Raises:
        HTTPException: 400 on invalid JSON, a missing ``conversation`` key,
            or a missing ``X-Session-ID`` header.
    """
    try:
        body = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e

    conversation = body.get("conversation")
    if not conversation:
        raise HTTPException(status_code=400, detail="Missing 'conversation' in payload")

    print("--------------------------------")
    print(body)
    print()
    temperature = body.get("temperature", 0.7)
    max_tokens = body.get("max_tokens", 256)
    model = body.get("model", "gemini-1.5-flash")

    # The session id comes from a header, not the body — the error message
    # now says so (the original wrongly claimed it was missing from the payload).
    session_id = request.headers.get("X-Session-ID")
    if not session_id:
        raise HTTPException(status_code=400, detail="Missing 'X-Session-ID' header")

    # First request for a session starts from an empty rolling summary.
    prev_summary = prev_summaries.setdefault(session_id, "")

    # NOTE(review): path is relative to the process CWD — confirm the service is
    # always launched from the expected directory, or anchor it via __file__.
    with open("../../configs/prompts.toml", "rb") as f:
        prompts = tomli.load(f)

    prompt = Template(prompts["summarization"]["prompt"])
    system_prompt = Template(prompts["summarization"]["system_prompt"])

    # Only the two most recent turns are folded into the prior summary.
    latest_conversations = conversation[-2:]

    # Strip bulky/irrelevant fields before embedding the turns in the prompt.
    # dict.pop(key, None) replaces the key-check + del pairs, and in-place
    # mutation makes the original no-op self-assignment unnecessary.
    for turn in latest_conversations:
        turn.pop("attachments", None)
        turn.pop("sessionId", None)

    summary = await client.models.generate_content(
        model=model,
        contents=[
            prompt.safe_substitute(
                previous_summary=prev_summary,
                latest_conversation=str(latest_conversations)
            )
        ],
        config=types.GenerateContentConfig(
            # safe_substitute for consistency with the user prompt above: a
            # missing $persona placeholder no longer raises KeyError.
            system_instruction=system_prompt.safe_substitute(persona="professional"),
            temperature=temperature,
            max_output_tokens=max_tokens,
            top_p=0.95,
        )
    )

    print(summary)
    summary_text = summary.text
    prev_summaries[session_id] = summary_text
    return {"summary": summary_text}