File size: 1,306 Bytes
316b9d4
420fa8a
 
 
316b9d4
 
 
420fa8a
c59035e
d5b3118
420fa8a
316b9d4
 
 
d5b3118
 
 
 
 
 
 
 
 
420fa8a
316b9d4
d5b3118
 
 
 
420fa8a
d5b3118
 
 
316b9d4
420fa8a
316b9d4
c526665
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
from typing import Optional
import sys

import vertexai
from dotenv import load_dotenv

sys.path.append("../")
from setup.vertex_ai_setup import initialize_vertexai_params
from vertexai import generative_models

# python-dotenv: copy variables from a local .env file (when one exists) into
# os.environ so the lookup below can see them.
load_dotenv()

# Mandatory setting: os.environ[...] raises KeyError at import time when the
# variable is absent. (Not referenced by the code visible in this module.)
VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECT"]

# Generation parameters applied when a caller passes no generation_config.
DEFAULT_GEN_CONFIG = dict(
    temperature=0.49,
    max_output_tokens=1024,
)

# Safety settings applied when a caller passes no safety_settings: every
# listed harm category gets the same BLOCK_LOW_AND_ABOVE threshold.
DEFAULT_SAFETY_SETTINGS = {
    harm_category: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
    for harm_category in (
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
}


def get_gemini_response(prompt_text, model: Optional[str] = None,
                        generation_config: Optional[dict] = None,
                        safety_settings: Optional[dict] = None) -> str:
    """Send one prompt to a Gemini model on Vertex AI and return the response text.

    Args:
        prompt_text: Prompt passed unchanged to ``GenerativeModel.generate_content``.
        model: Name of the Gemini model to use. When omitted or ``None``,
            ``"gemini-1.0-pro"`` is used. (Previously this argument had no
            default even though ``None`` was already handled; it is now
            optional, which keeps every existing call working.)
        generation_config: Generation parameters for the model. Falls back to
            the module-level ``DEFAULT_GEN_CONFIG`` when ``None``.
        safety_settings: Mapping of harm category to block threshold. Falls
            back to the module-level ``DEFAULT_SAFETY_SETTINGS`` when ``None``.

    Returns:
        The ``text`` attribute of the model's response.

    NOTE(review): reading ``.text`` presumably raises when the response has no
    usable candidate (e.g. it was blocked by the safety settings); confirm
    against the SDK and handle at call sites if that matters.
    """
    # Vertex AI setup runs on every call; presumably it is idempotent —
    # confirm in setup.vertex_ai_setup.
    initialize_vertexai_params()

    # Keep the caller's model *name* and the constructed model object in
    # separate variables instead of rebinding one parameter to both.
    model_name = "gemini-1.0-pro" if model is None else model
    gemini_model = generative_models.GenerativeModel(
        model_name,
        generation_config=DEFAULT_GEN_CONFIG if generation_config is None else generation_config,
        safety_settings=DEFAULT_SAFETY_SETTINGS if safety_settings is None else safety_settings,
    )

    model_response = gemini_model.generate_content(prompt_text)
    return model_response.text