"""Helpers for building an Azure OpenAI chat model and an Azure AI Search vector store."""

import os

from langchain.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from msal import ConfidentialClientApplication


class LLM:
    """Thin wrapper around a LangChain chat model with streaming and batch helpers."""

    def __init__(self, llm):
        self.llm = llm
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Stream the completion, yielding the accumulated text after each new chunk."""
        self.llm.streaming = True
        streamed_content = self.llm.stream(prompt.format_messages(**prompt_arguments))
        output = ""
        for chunk in streamed_content:
            output += chunk.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Return the model's answer for a formatted prompt as plain text."""
        self.llm.callbacks = self.callbacks
        return self.llm.predict_messages(
            prompt.format_messages(**prompt_arguments)
        ).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Async variant of get_prediction; returns the answer as plain text."""
        self.llm.callbacks = self.callbacks
        prediction = await self.llm.apredict_messages(
            prompt.format_messages(**prompt_arguments)
        )
        return prediction.content

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts sequentially and return their answers as a list of strings."""
        self.llm.callbacks = self.callbacks
        predictions = []
        for prompt_, prompt_args_ in zip(prompts.values(), prompts_arguments):
            prediction = await self.llm.apredict_messages(
                prompt_.format_messages(**prompt_args_)
            )
            predictions.append(prediction.content)
        return predictions
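
# A minimal usage sketch (not part of the original module) showing how the wrapper is
# meant to be called; the prompt template and question below are illustrative assumptions.
#
#     from langchain_core.prompts import ChatPromptTemplate
#
#     prompt = ChatPromptTemplate.from_messages(
#         [("system", "You are a helpful assistant."), ("human", "{question}")]
#     )
#     llm = get_llm()
#     print(llm.get_prediction(prompt, {"question": "What is Azure AI Search?"}))
#     for partial_answer in llm.stream(prompt, {"question": "What is Azure AI Search?"}):
#         print(partial_answer)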


def get_token() -> str | None:
    """Acquire an Azure AD access token via the MSAL client-credentials flow."""
    app = ConfidentialClientApplication(
        client_id=os.getenv("CLIENT_ID"),
        client_credential=os.getenv("CLIENT_SECRET"),
        authority=f"https://login.microsoftonline.com/{os.getenv('TENANT_ID')}",
    )
    result = app.acquire_token_for_client(scopes=[os.getenv("SCOPE")])
    # MSAL returns a dict containing either "access_token" or error details.
    if result and "access_token" in result:
        return result["access_token"]
    return None
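
# Environment variables this module reads (names taken from the calls in this file; the
# values shown are placeholder assumptions, not real configuration):
#
#     CLIENT_ID=<app registration client id>
#     CLIENT_SECRET=<app registration client secret>
#     TENANT_ID=<azure ad tenant id>
#     SCOPE=<e.g. https://cognitiveservices.azure.com/.default>
#     OPENAI_API_ENDPOINT=<e.g. https://<resource>.openai.azure.com/openai/deployments/>
#     OPENAI_API_VERSION=<azure openai api version>
#     DEPLOYMENT_ID=<chat model deployment name>
#     DEPLOYMENT_EMB_ID=<embedding model deployment name>
#     VECTOR_STORE_ADDRESS=<https://<search-service>.search.windows.net>
#     VECTOR_STORE_PASSWORD=<azure ai search admin key>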


def get_llm():
    """Return an LLM wrapper around AzureChatOpenAI configured from environment variables."""
    token = get_token()
    if token is None:
        raise RuntimeError("Could not acquire an Azure AD token for the OpenAI endpoint.")
    os.environ["OPENAI_API_KEY"] = token
    # The full chat-completions URL (deployment and api-version included) is used as the endpoint.
    os.environ["AZURE_OPENAI_ENDPOINT"] = (
        f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_ID')}"
        f"/chat/completions?api-version={os.getenv('OPENAI_API_VERSION')}"
    )
    return LLM(AzureChatOpenAI())


def get_vectorstore(index_name, model="text-embedding-ada-002"):
    """Return an AzureSearch vector store backed by Azure OpenAI embeddings."""
    os.environ["AZURE_OPENAI_ENDPOINT"] = (
        f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_EMB_ID')}"
        f"/embeddings?api-version={os.getenv('OPENAI_API_VERSION')}"
    )
    token = get_token()
    if token is None:
        raise RuntimeError("Could not acquire an Azure AD token for the embeddings endpoint.")
    os.environ["AZURE_OPENAI_API_KEY"] = token
    aoai_embeddings = AzureOpenAIEmbeddings(
        azure_deployment=model,
        openai_api_version=os.getenv("OPENAI_API_VERSION"),
    )
    # AzureSearch uses the embedding function to vectorise queries and documents.
    vector_store: AzureSearch = AzureSearch(
        azure_search_endpoint=os.getenv("VECTOR_STORE_ADDRESS"),
        azure_search_key=os.getenv("VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        embedding_function=aoai_embeddings.embed_query,
    )
    return vector_store
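

# A hedged end-to-end sketch combining the vector store and the LLM wrapper for a simple
# retrieval-augmented query; the index name, question, and prompt are assumptions for
# illustration, not values from the original module.
if __name__ == "__main__":
    from langchain_core.prompts import ChatPromptTemplate

    vector_store = get_vectorstore(index_name="demo-index")  # assumed index name
    docs = vector_store.similarity_search("What is Azure AI Search?", k=3)
    context = "\n\n".join(doc.page_content for doc in docs)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Answer using only the provided context:\n{context}"),
            ("human", "{question}"),
        ]
    )
    llm = get_llm()
    print(llm.get_prediction(prompt, {"context": context, "question": "What is Azure AI Search?"}))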