File size: 1,219 Bytes
15aea1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from enum import Enum
from typing import Optional

from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama, OllamaEmbeddings

from .embedding import CustomEmbedding


class LLMModel(Enum):
    """Enumerate the chat-model backends by their LangChain class name.

    Each member's value is the name of the chat class this module imports
    for that backend (``ChatOllama`` or ``ChatGroq``).
    """

    # Backend served through ``langchain_ollama.ChatOllama``.
    OLLAMA = "ChatOllama"
    # Backend served through ``langchain_groq.ChatGroq``.
    GROQ = "ChatGroq"


def get_llm_model_chat(
    temperature: float = 0.01, max_tokens: Optional[int] = None
):
    """Build the chat model selected by environment variables.

    When ``USE_OLLAMA_CHAT`` is exactly ``"1"``, a ``ChatOllama`` is returned
    for the model named by ``OLLAMA_MODEL``. Otherwise a ``ChatGroq`` is
    returned for the model named by ``GROQ_MODEL_NAME``.

    Args:
        temperature: Sampling temperature forwarded to the chat model.
        max_tokens: Cap on generated tokens. ``None`` (the default) passes
            no explicit cap. It is forwarded as ``num_predict`` for Ollama
            and as ``max_tokens`` for Groq.

    Returns:
        A ``ChatOllama`` or ``ChatGroq`` instance.
    """
    # os.getenv returns None when the variable is unset, and None != "1",
    # so no str() coercion is needed.
    if os.getenv("USE_OLLAMA_CHAT") == "1":
        return ChatOllama(
            model=os.getenv("OLLAMA_MODEL"),
            temperature=temperature,
            # ChatOllama names the generation-length cap ``num_predict``.
            num_predict=max_tokens,
        )
    return ChatGroq(
        model=os.getenv("GROQ_MODEL_NAME"),
        temperature=temperature,
        max_tokens=max_tokens,
    )


def get_llm_model_embedding():
    """Build the embedding model selected by environment variables.

    When ``USE_HF_EMBEDDING`` is exactly ``"1"``, the project's
    ``CustomEmbedding`` is returned. Otherwise an ``OllamaEmbeddings`` is
    built from these variables:

    * ``OLLAM_EMB``: the embedding model name.
    * ``OLLAMA_HOST``: the server URL passed as ``base_url``. ``None`` is
      passed when it is unset.
    * ``OLLAMA_TOKEN``: a bearer token. It is sent as an ``Authorization``
      header only when ``OLLAMA_HOST`` is set and the token is non-empty.

    Returns:
        A ``CustomEmbedding`` or ``OllamaEmbeddings`` instance.
    """
    if os.getenv("USE_HF_EMBEDDING") == "1":
        return CustomEmbedding()

    host = os.getenv("OLLAMA_HOST")
    token = os.getenv("OLLAMA_TOKEN")

    # Attach credentials only for an explicitly configured host. Without a
    # token, omit the header entirely; "Bearer " with an empty credential is
    # not a valid Authorization value.
    client_kwargs = None
    if host is not None and token:
        client_kwargs = {"headers": {"Authorization": f"Bearer {token}"}}

    return OllamaEmbeddings(
        # NOTE(review): "OLLAM_EMB" differs from the other OLLAMA_* variable
        # names. It is kept unchanged so existing deployments keep working.
        # Confirm the intended name before renaming it.
        model=os.getenv("OLLAM_EMB"),
        base_url=host,
        client_kwargs=client_kwargs,
    )