File size: 1,491 Bytes
7bd11ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Optional

from langchain_core.prompts import PromptTemplate
from langchain_google_vertexai import ChatVertexAI
from pydantic import BaseModel, ConfigDict
from vertexai.generative_models import HarmCategory, HarmBlockThreshold


class VertexAIModelConfig(BaseModel):
    """Configuration for :class:`VertexAIModel`.

    Attributes are validated by pydantic on construction.
    """

    # ``model_kwargs`` starts with "model_", which pydantic v2 reserves as a
    # protected namespace (it would emit a warning); an empty tuple disables
    # that protection for this model.
    model_config = ConfigDict(protected_namespaces=())

    # Template text handed to ``PromptTemplate``; it is expected to use the
    # ``{context}`` and ``{question}`` placeholders.
    prompt_template: str
    # Extra keyword arguments forwarded verbatim to ``ChatVertexAI``.
    # A literal ``{}`` default is safe here: pydantic copies mutable defaults
    # per instance rather than sharing one dict.
    model_kwargs: dict = {}


class VertexAIModel:
    """Build a LangChain ``ChatVertexAI`` chat model and prompt from a config.

    The chat model is created lazily on first access of :attr:`model` and
    reused afterwards.
    """

    # Default safety settings: each listed harm category is set to
    # BLOCK_NONE. A caller may override this entirely by supplying
    # ``safety_settings`` in ``config.model_kwargs``.
    _DEFAULT_SAFETY_SETTINGS = {
        HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    }

    def __init__(self, config: VertexAIModelConfig):
        """Store *config*; no client is created until :attr:`model` is read."""
        self.config = config
        # Cache for the ChatVertexAI instance built by the ``model`` property.
        self._model = None

    @property
    def model(self) -> ChatVertexAI:
        """Return the ``ChatVertexAI`` client, creating it on first access.

        Keyword arguments are ``config.model_kwargs`` layered over the default
        ``safety_settings``. A ``safety_settings`` key in ``model_kwargs``
        therefore replaces the defaults instead of raising ``TypeError`` for
        a duplicate keyword argument.

        Note: the client is cached, so later edits to ``config.model_kwargs``
        do not affect an already-created client.
        """
        if self._model is None:
            kwargs = {
                # Copy so the client never shares the class-level dict.
                "safety_settings": dict(self._DEFAULT_SAFETY_SETTINGS),
                **self.config.model_kwargs,
            }
            self._model = ChatVertexAI(**kwargs)
        return self._model

    @property
    def prompt(self) -> Optional[PromptTemplate]:
        """Return a ``PromptTemplate`` over ``context`` and ``question``.

        Returns ``None`` when ``config.prompt_template`` is empty or falsy.
        """
        if not self.config.prompt_template:
            return None
        return PromptTemplate(
            input_variables=["context", "question"],
            template=self.config.prompt_template,
        )