|
from msal import ConfidentialClientApplication |
|
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI |
|
from langchain_groq import ChatGroq |
|
from langchain.vectorstores.azuresearch import AzureSearch |
|
import os |
|
|
|
|
|
class LLM:
    """Thin wrapper around a LangChain chat model with sync, async and streaming helpers.

    Every helper formats a prompt template with keyword arguments
    (``prompt.format_messages(**prompt_arguments)``) and sends the resulting
    messages to the wrapped model. Before each call, the handlers in
    ``self.callbacks`` are installed on the wrapped model.
    """

    def __init__(self, llm):
        """Wrap ``llm``.

        Args:
            llm: A LangChain chat model (e.g. ``AzureChatOpenAI`` or ``ChatGroq``).
        """
        self.llm = llm
        # Callback handlers assigned to ``self.llm.callbacks`` before every call.
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Stream the model's answer, yielding the cumulative text so far.

        Each yielded value is the whole output accumulated up to and including
        the latest chunk, not just the new chunk.

        Args:
            prompt: Template object exposing ``format_messages(**kwargs)``.
            prompt_arguments: Mapping of template variables.

        Yields:
            str: The concatenated ``content`` of all chunks received so far.
        """
        # NOTE: permanently switches the wrapped model into streaming mode.
        self.llm.streaming = True
        # Install callbacks here too, as the prediction helpers do. Otherwise
        # streamed calls would only see them if a prediction happened to run first.
        self.llm.callbacks = self.callbacks
        chunks = self.llm.stream(prompt.format_messages(**prompt_arguments))
        output = ""
        for chunk in chunks:
            output += chunk.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Run the model synchronously and return the reply text.

        Args:
            prompt: Template object exposing ``format_messages(**kwargs)``.
            prompt_arguments: Mapping of template variables.

        Returns:
            str: The ``content`` of the model's reply message.
        """
        self.llm.callbacks = self.callbacks
        # ``invoke`` replaces the deprecated ``predict_messages``.
        reply = self.llm.invoke(prompt.format_messages(**prompt_arguments))
        return reply.content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Run the model asynchronously and return the reply message.

        Unlike ``get_prediction``, this returns the message object itself
        (not its ``content``); callers rely on that shape.

        Args:
            prompt: Template object exposing ``format_messages(**kwargs)``.
            prompt_arguments: Mapping of template variables.

        Returns:
            The message object returned by the model.
        """
        self.llm.callbacks = self.callbacks
        # ``ainvoke`` replaces the deprecated ``apredict_messages``.
        return await self.llm.ainvoke(prompt.format_messages(**prompt_arguments))

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts one after another and collect the reply texts.

        Args:
            prompts: Mapping whose values are prompt templates. Pairing with
                ``prompts_arguments`` follows the mapping's iteration order.
            prompts_arguments: Sequence of template-variable mappings, one per prompt.

        Returns:
            list[str]: Reply ``content`` for each (prompt, arguments) pair. As
            with ``zip``, extra items in the longer input are ignored.
        """
        self.llm.callbacks = self.callbacks
        predictions = []
        # Awaited sequentially, so requests are issued one at a time.
        for prompt, prompt_args in zip(prompts.values(), prompts_arguments):
            reply = await self.llm.ainvoke(prompt.format_messages(**prompt_args))
            predictions.append(reply.content)
        return predictions
|
|
|
|
|
def get_llm_api(groq_model_name):
    """Build the chat-model wrapper selected by environment configuration.

    If ``EKI_OPENAI_LLM_DEPLOYMENT_NAME`` is set to a non-empty value, an
    Azure OpenAI chat deployment is used. Otherwise the Groq API is used with
    ``groq_model_name``. Both are configured with ``temperature=0`` and
    ``max_tokens=2048``.

    Args:
        groq_model_name: Groq model identifier; only used on the Groq path.

    Returns:
        An ``LLM`` wrapping the selected LangChain chat model.
    """
    azure_deployment = os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME")

    if not azure_deployment:
        print("Using GROQ API")
        groq_chat = ChatGroq(model=groq_model_name, temperature=0, max_tokens=2048)
        return LLM(groq_chat)

    print("Using Azure OpenAI API")
    azure_chat = AzureChatOpenAI(
        deployment_name=azure_deployment,
        openai_api_key=os.getenv("EKI_OPENAI_API_KEY"),
        azure_endpoint=os.getenv("EKI_OPENAI_LLM_API_ENDPOINT"),
        openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
        streaming=True,
        temperature=0,
        max_tokens=2048,
        # Stop sequence: generation halts at "<|im_end|>" (presumably the
        # ChatML end-of-turn marker of the deployed model; confirm).
        stop=["<|im_end|>"],
    )
    return LLM(azure_chat)
|
|
|
|
|
def get_vectorstore_api(index_name):
    """Open the Azure Search vector store for ``index_name``.

    Query embeddings come from an Azure OpenAI ``text-embedding-ada-002``
    embeddings client configured from the ``EKI_OPENAI_*`` environment
    variables. The search endpoint and key come from
    ``EKI_VECTOR_STORE_ADDRESS`` and ``EKI_VECTOR_STORE_PASSWORD``.

    Args:
        index_name: Name of the search index to bind to.

    Returns:
        AzureSearch: Vector store whose embedding function is the client's
        ``embed_query``.

    Raises:
        KeyError: If ``EKI_OPENAI_EMB_API_ENDPOINT`` is unset. It is the only
            variable read with ``os.environ[...]``; the others fall back to None.
    """
    embedder = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        azure_deployment=os.getenv("EKI_OPENAI_EMB_DEPLOYMENT_NAME"),
        api_key=os.getenv("EKI_OPENAI_API_KEY"),
        azure_endpoint=os.environ["EKI_OPENAI_EMB_API_ENDPOINT"],
        openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
    )

    return AzureSearch(
        azure_search_endpoint=os.getenv("EKI_VECTOR_STORE_ADDRESS"),
        azure_search_key=os.getenv("EKI_VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        embedding_function=embedder.embed_query,
    )
|
|