import gradio as gr
from openai import AzureOpenAI
import os
from dotenv import load_dotenv
import time
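# Runtime dependencies implied by the imports above (versions are not pinned in the
# original source): gradio, openai (the v1 SDK providing AzureOpenAI), python-dotenv.
#   pip install gradio openai python-dotenv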
def load_environment():
    """Load environment variables."""
    load_dotenv(override=True)
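# Expected .env layout (a sketch: these are the variable names this script reads;
# the values are placeholders):
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
#   AZURE_OPENAI_API_KEY=<your-api-key>
#   AZURE_OPENAI_VECTOR_STORE_ID=<your-vector-store-id>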
def initialize_openai_client():
    """Initialize the Azure OpenAI client."""
    return AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version="2024-10-01-preview"
    )
def create_assistant(client, vector_store_id):
    """Create an assistant configured for file search over the given vector store."""
    return client.beta.assistants.create(
        model="gpt-4o",
        # Japanese system instruction: "Unless instructed otherwise, please answer in Japanese."
        instructions="指示がない限り、日本語で回答してください。",
        tools=[{
            "type": "file_search",
            "file_search": {"ranking_options": {"ranker": "default_2024_08_21", "score_threshold": 0}}
        }],
        tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
        temperature=0
    )
def create_thread():
    """Create a new thread (uses the module-level `client` defined below)."""
    return client.beta.threads.create()
def clear_thread(state):
    """Reset the session and clear the chat history."""
    new_state = initialize_session()  # Create a fresh thread; the old one is abandoned
    return [], "", new_state
def get_annotations(msg):
    """Collect the unique file IDs cited in a message's file_search annotations."""
    annotations = msg.content[0].text.annotations
    file_ids = []
    if annotations:
        for annotation in annotations:
            file_id = annotation.file_citation.file_id
            if file_id in file_ids:
                continue
            print("file_id", file_id)
            cited_file = client.files.retrieve(file_id)
            print("filename", cited_file.filename)
            try:
                # Retrieve the cited file's content to confirm it is accessible; the result is not used further
                content = client.files.content(file_id)
            except Exception as e:
                print(e)
            file_ids.append(file_id)
    return file_ids
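# Possible extension (a sketch, not part of the original app): render the citations
# collected above as footnotes. `annotation.text` and `file_citation.file_id` follow
# the Assistants API annotation schema; the footnote formatting itself is an assumption.
def format_citations(msg):
    """Replace inline citation markers with [n] and append the cited filenames."""
    text = msg.content[0].text.value
    footnotes = []
    for idx, annotation in enumerate(msg.content[0].text.annotations):
        cited_file = client.files.retrieve(annotation.file_citation.file_id)
        text = text.replace(annotation.text, f"[{idx + 1}]")  # swap the raw marker for [n]
        footnotes.append(f"[{idx + 1}] {cited_file.filename}")
    return text + ("\n\n" + "\n".join(footnotes) if footnotes else "")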
def get_chatbot_response(client, thread_id, assistant_id, message):
    """Get the chatbot response for a given message."""
    # Add the user message to the thread
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message
    )
    # Start a run and poll until it reaches a terminal state
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id
    )
    while run.status in ["queued", "in_progress", "cancelling"]:
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id
        )
    if run.status == "completed":
        # Messages are listed newest first, so the first one is the assistant's reply
        messages = client.beta.threads.messages.list(thread_id=thread_id)
        for msg in messages:
            # file_ids = get_annotations(msg)
            main_text = msg.content[0].text.value
            return main_text
    elif run.status == "requires_action":
        # The assistant requested tool outputs; not handled in this demo
        pass
    return "Unable to retrieve a response."  # Fallback response
def chatbot_response(history, message):
    """Wrapper to generate a chatbot response.

    Not used by the UI below (respond() is used instead); it expects a module-level
    `thread`, which this script does not define.
    """
    global thread
    # Get response from the API
    assistant_response = get_chatbot_response(client, thread.id, assistant.id, message)
    # Update chat history
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": assistant_response})
    return history, ""
# Load environment variables and set up the client and assistant
load_environment()
client = initialize_openai_client()
vector_store_id = os.getenv("AZURE_OPENAI_VECTOR_STORE_ID")
assistant = create_assistant(client, vector_store_id)
def respond(message, chat_history, state):
    """Update the chat history and per-session state."""
    thread_id = state["thread_id"]
    bot_message = get_chatbot_response(client, thread_id, assistant.id, message)
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": bot_message})
    return "", chat_history
def initialize_session():
    """Initialize an independent thread for each session."""
    thread = create_thread()
    return {"thread_id": thread.id}
with gr.Blocks() as demo:
    gr.Markdown("""
    # Azure OpenAI Assistants API x Gradio x Zenn
    This is a Gradio demo of Retrieval-Augmented Generation (RAG) using the Azure OpenAI Assistants API, applied to [Zenn articles](https://zenn.dev/nakamura196).
    """)
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Enter a message here...")
    state = gr.State(initialize_session)  # Initialize per-session state
    clear = gr.Button("Clear")

    msg.submit(respond, [msg, chatbot, state], [msg, chatbot])
    # Output the new state as well so the Clear button actually starts a fresh thread
    clear.click(clear_thread, inputs=[state], outputs=[chatbot, msg, state])
if __name__ == "__main__":
    demo.launch()
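# Run locally with `python app.py`; Gradio serves the UI at http://127.0.0.1:7860 by
# default (host/port can be changed via demo.launch(server_name=..., server_port=...)).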