File size: 2,906 Bytes
579d749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from langchain_openai import AzureChatOpenAI
from msal import ConfidentialClientApplication
from langchain_openai import AzureOpenAIEmbeddings
from langchain.vectorstores.azuresearch import AzureSearch
import os


class LLM:
    """Thin convenience wrapper around a LangChain chat model.

    Exposes synchronous, streaming, and asynchronous prediction helpers
    that all take a prompt template plus the keyword arguments used to
    format it into messages.
    """

    def __init__(self, llm):
        # llm: a LangChain chat model instance (e.g. AzureChatOpenAI).
        self.llm = llm
        # Callbacks re-attached to the model before each prediction call.
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Yield the growing response text as chunks arrive.

        Each yielded value is the full text accumulated so far (not the
        individual delta), which suits progressive UI rendering.
        """
        self.llm.streaming = True
        accumulated = ""
        for chunk in self.llm.stream(prompt.format_messages(**prompt_arguments)):
            accumulated += chunk.content
            yield accumulated

    def get_prediction(self, prompt, prompt_arguments):
        """Run the model synchronously and return the response text."""
        self.llm.callbacks = self.callbacks
        messages = prompt.format_messages(**prompt_arguments)
        return self.llm.predict_messages(messages).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Run the model asynchronously and return the response message.

        NOTE(review): this returns the full message object, while
        get_prediction returns only `.content` — confirm callers expect
        this asymmetry before unifying.
        """
        self.llm.callbacks = self.callbacks
        messages = prompt.format_messages(**prompt_arguments)
        return await self.llm.apredict_messages(messages)

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts sequentially and return their response texts.

        `prompts` is a mapping of name -> prompt template;
        `prompts_arguments` is an iterable of kwarg dicts paired with the
        prompts in iteration order.
        """
        self.llm.callbacks = self.callbacks
        results = []
        for key, kwargs in zip(prompts.keys(), prompts_arguments):
            message = await self.llm.apredict_messages(
                prompts[key].format_messages(**kwargs)
            )
            results.append(message.content)
        return results


def get_token() -> str | None:
    """Acquire an Azure AD access token via the client-credentials flow.

    Reads CLIENT_ID, CLIENT_SECRET, TENANT_ID and SCOPE from the
    environment.

    Returns:
        The bearer token string, or None when acquisition fails.
    """
    app = ConfidentialClientApplication(
        client_id=os.getenv("CLIENT_ID"),
        client_credential=os.getenv("CLIENT_SECRET"),
        authority=f"https://login.microsoftonline.com/{os.getenv('TENANT_ID')}",
    )
    result = app.acquire_token_for_client(scopes=[os.getenv("SCOPE")])
    # On failure MSAL returns a dict containing "error" and no
    # "access_token"; .get avoids a KeyError and keeps the None contract.
    if result is not None:
        return result.get("access_token")
    return None


def get_llm():
    """Build an LLM wrapper backed by AzureChatOpenAI.

    Side effects: sets OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT in the
    process environment before constructing the client.

    Returns:
        An LLM instance wrapping a freshly constructed AzureChatOpenAI.

    Raises:
        RuntimeError: if an access token could not be acquired.
    """
    token = get_token()
    if token is None:
        # os.environ values must be str; fail loudly here instead of with
        # an opaque TypeError from the environ assignment below.
        raise RuntimeError("Failed to acquire Azure AD access token")
    os.environ["OPENAI_API_KEY"] = token
    # NOTE(review): get_vectorstore sets AZURE_OPENAI_API_KEY instead of
    # OPENAI_API_KEY — confirm which variable this AzureChatOpenAI
    # version actually reads.
    os.environ["AZURE_OPENAI_ENDPOINT"] = (
        f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_ID')}/chat/completions?api-version={os.getenv('OPENAI_API_VERSION')}"
    )

    return LLM(AzureChatOpenAI())


def get_vectorstore(index_name, model="text-embedding-ada-002"):
    """Create an AzureSearch vector store over an Azure OpenAI embedder.

    Side effects: sets AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY in
    the process environment before constructing the embeddings client.

    Args:
        index_name: Name of the Azure AI Search index to target.
        model: Embedding deployment name (default "text-embedding-ada-002").

    Returns:
        A configured AzureSearch vector store.

    Raises:
        RuntimeError: if an access token could not be acquired.
    """
    token = get_token()
    if token is None:
        # os.environ values must be str; fail loudly here instead of with
        # an opaque TypeError from the environ assignment below.
        raise RuntimeError("Failed to acquire Azure AD access token")
    os.environ["AZURE_OPENAI_ENDPOINT"] = (
        f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_EMB_ID')}/embeddings?api-version={os.getenv('OPENAI_API_VERSION')}"
    )
    os.environ["AZURE_OPENAI_API_KEY"] = token

    aoai_embeddings = AzureOpenAIEmbeddings(
        azure_deployment=model,
        openai_api_version=os.getenv("OPENAI_API_VERSION"),
    )

    vector_store: AzureSearch = AzureSearch(
        azure_search_endpoint=os.getenv("VECTOR_STORE_ADDRESS"),
        azure_search_key=os.getenv("VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        embedding_function=aoai_embeddings.embed_query,
    )

    return vector_store