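"""Factory helpers for the project's LLM and vector store.

Wraps a LangChain chat model (Azure OpenAI when the EKI_OPENAI_* variables
are set, otherwise Groq) in a small ``LLM`` class with streaming and async
helpers, and builds an ``AzureSearch`` vector store backed by Azure OpenAI
embeddings.
"""
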
import os

from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_groq import ChatGroq
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings


class LLM:
    """Thin wrapper around a LangChain chat model, adding streaming and
    async prediction helpers that operate on prompt templates."""

    def __init__(self, llm):
        self.llm = llm
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Stream a completion, yielding the full text accumulated so far
        after each chunk (convenient for redrawing a UI element)."""
        self.llm.streaming = True
        streamed_content = self.llm.stream(prompt.format_messages(**prompt_arguments))
        output = ""
        for chunk in streamed_content:
            output += chunk.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Run the model synchronously and return the response text."""
        self.llm.callbacks = self.callbacks
        return self.llm.predict_messages(
            prompt.format_messages(**prompt_arguments)
        ).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Async variant of get_prediction; returns the response text."""
        self.llm.callbacks = self.callbacks
        prediction = await self.llm.apredict_messages(
            prompt.format_messages(**prompt_arguments)
        )
        return prediction.content

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts sequentially and return their response texts."""
        self.llm.callbacks = self.callbacks
        predictions = []
        for prompt_, prompt_args_ in zip(prompts.values(), prompts_arguments):
            prediction = await self.llm.apredict_messages(
                prompt_.format_messages(**prompt_args_)
            )
            predictions.append(prediction.content)
        return predictions


def get_llm_api(groq_model_name):
    """Return an LLM wrapper, preferring Azure OpenAI when configured.

    Falls back to the Groq API when EKI_OPENAI_LLM_DEPLOYMENT_NAME is not set.
    """
    if os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
        print("Using Azure OpenAI API")
        return LLM(
            AzureChatOpenAI(
                deployment_name=os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"),
                openai_api_key=os.getenv("EKI_OPENAI_API_KEY"),
                azure_endpoint=os.getenv("EKI_OPENAI_LLM_API_ENDPOINT"),
                openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
                streaming=True,
                temperature=0,
                max_tokens=2048,
                stop=["<|im_end|>"],
            )
        )
    else:
        print("Using GROQ API")
        return LLM(
            ChatGroq(
                model=groq_model_name,
                temperature=0,
                max_tokens=2048,
            )
        )


def get_vectorstore_api(index_name):
    """Build an AzureSearch vector store over the given index.

    Queries are embedded with the Azure OpenAI text-embedding-ada-002
    deployment configured via the EKI_OPENAI_EMB_* variables.
    """
    aoai_embeddings = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        azure_deployment=os.getenv("EKI_OPENAI_EMB_DEPLOYMENT_NAME"),
        api_key=os.getenv("EKI_OPENAI_API_KEY"),
        azure_endpoint=os.getenv("EKI_OPENAI_EMB_API_ENDPOINT"),
        openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
    )

    vector_store: AzureSearch = AzureSearch(
        azure_search_endpoint=os.getenv("EKI_VECTOR_STORE_ADDRESS"),
        azure_search_key=os.getenv("EKI_VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        embedding_function=aoai_embeddings.embed_query,
    )

    return vector_store
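

if __name__ == "__main__":
    # Minimal usage sketch, under assumptions: the EKI_* environment variables
    # used above are set, "demo-index" is a hypothetical Azure AI Search index,
    # and "llama-3.1-8b-instant" stands in for whatever Groq model you deploy.
    from langchain_core.prompts import ChatPromptTemplate

    llm = get_llm_api(groq_model_name="llama-3.1-8b-instant")
    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are a helpful assistant."), ("human", "{question}")]
    )

    # Blocking call: returns the complete response text in one shot.
    print(llm.get_prediction(prompt, {"question": "What is a vector store?"}))

    # Streaming call: each iteration yields the text accumulated so far.
    for partial in llm.stream(prompt, {"question": "What is a vector store?"}):
        print(partial, end="\r")

    # Similarity search against the Azure AI Search index.
    vector_store = get_vectorstore_api("demo-index")
    for doc in vector_store.similarity_search("example query", k=3):
        print(doc.page_content[:80])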