import os
from typing import List, Tuple

from flask import Flask, request, jsonify
from google.cloud import aiplatform_v1  # pip install google-cloud-aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Initialize Flask app
app = Flask(__name__)

# Google Cloud project ID and location, read from the environment
project_id = os.getenv("GOOGLE_CLOUD_PROJECT_ID")
location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")  # default region if unset

# Initialize the Vertex AI prediction client against the regional API endpoint
vertex_ai_client = aiplatform_v1.PredictionServiceClient(
    client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
)

# Fully qualified resource name of the deployed model endpoint.
# VERTEX_ENDPOINT_ID is a placeholder environment variable; set it to the
# numeric ID of your own deployed endpoint.
endpoint_id = os.getenv("VERTEX_ENDPOINT_ID")
endpoint = f"projects/{project_id}/locations/{location}/endpoints/{endpoint_id}"

# System message prepended to every conversation
SYSTEM_MESSAGE = "You are a helpful assistant."

# Function to generate AI response using Google Gemini (Vertex AI)
def generate_response(
    user_input: str, 
    history: List[Tuple[str, str]], 
    max_tokens: int = 150, 
    temperature: float = 0.7, 
    top_p: float = 1.0
) -> str:
    """
    Generates a response using the Google Gemini (Vertex AI) API.
    Args:
        user_input: The user's input message.
        history: A list of tuples containing the conversation history 
                 (user input, AI response).
        max_tokens: The maximum number of tokens in the generated response.
        temperature: Controls the randomness of the generated response.
        top_p: Controls the nucleus sampling probability.
    Returns:
        str: The generated response from the AI model.
    """
    try:
        # Prepare the history and current input for the model
        conversation = [{"role": "system", "content": SYSTEM_MESSAGE}]
        for user_message, assistant_message in history:
            conversation.append({"role": "user", "content": user_message})
            conversation.append({"role": "assistant", "content": assistant_message})

        # Add the current user input
        conversation.append({"role": "user", "content": user_input})

        # Build the request payload. The GAPIC client expects protobuf Value
        # objects, so the plain dicts are converted with json_format.ParseDict.
        # NOTE: the exact instance/parameter schema depends on the model that
        # is deployed behind the endpoint; the keys below follow the original
        # sketch and may need adjusting for your model.
        instances = [json_format.ParseDict({"content": conversation}, Value())]
        parameters = json_format.ParseDict(
            {
                "temperature": temperature,
                "max_output_tokens": max_tokens,
                "top_p": top_p,
            },
            Value(),
        )

        # Send the request to the Vertex AI prediction endpoint
        response = vertex_ai_client.predict(
            endpoint=endpoint, instances=instances, parameters=parameters
        )

        # Extract the generated text from the first prediction
        prediction = dict(response.predictions[0])
        return prediction.get("content", "Sorry, I couldn't generate a response.")

    except Exception:
        # Log the full traceback server-side and return a generic message
        app.logger.exception("Vertex AI request failed")
        return "Error: An unexpected error occurred while processing your request."

# Route to handle user input and generate responses
@app.route("/chat", methods=["POST"])
def chat():
    try:
        # Parse the JSON body; fall back to an empty dict if none was sent
        payload = request.get_json(silent=True) or {}
        user_input = payload.get("user_input", "")
        history = payload.get("history", [])
        
        # Generate the AI response
        response = generate_response(
            user_input=user_input, 
            history=history
        )
        
        # Return the response as JSON
        return jsonify({"response": response})

    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Run the app
    app.run(debug=True, host="0.0.0.0", port=5000)
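
# Example request against the running server (assumes the default host/port
# configured above; the reply text is whatever the deployed model returns):
#
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"user_input": "Hello", "history": []}'
#
# Expected shape of the JSON response: {"response": "<model reply>"}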