# zenn / app.py
# Author: nakamura196
# Commit message: chore: minor update
# Commit: 17b121a
import gradio as gr
from openai import AzureOpenAI
import os
from dotenv import load_dotenv
import time
def load_environment():
    """Load variables from a ``.env`` file into the process environment.

    ``override=True`` means values from the file replace variables that are
    already set in the environment.
    """
    prefer_file_values = True
    load_dotenv(override=prefer_file_values)
def initialize_openai_client():
    """Build an ``AzureOpenAI`` client from environment configuration.

    Reads ``AZURE_OPENAI_ENDPOINT`` and ``AZURE_OPENAI_API_KEY`` from the
    environment and pins the API version this app was written against.
    """
    settings = {
        "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "api_version": "2024-10-01-preview",
    }
    return AzureOpenAI(**settings)
def create_assistant(client, vector_store_id):
    """Create the RAG assistant whose file_search tool queries one vector store.

    Args:
        client: Azure OpenAI client exposing ``beta.assistants``.
        vector_store_id: ID of the vector store the file_search tool searches.

    Returns:
        Whatever ``client.beta.assistants.create`` returns (the new assistant).
    """
    # score_threshold=0 keeps every ranked chunk rather than filtering by score.
    file_search_tool = {
        "type": "file_search",
        "file_search": {
            "ranking_options": {
                "ranker": "default_2024_08_21",
                "score_threshold": 0,
            }
        },
    }
    resources = {"file_search": {"vector_store_ids": [vector_store_id]}}
    # NOTE(review): the instructions literal appears mis-encoded in this copy of
    # the file; presumably Japanese for "answer in Japanese unless instructed
    # otherwise" — verify the file's encoding.
    return client.beta.assistants.create(
        model="gpt-4o",
        instructions="ๆŒ‡็คบใŒใชใ„้™ใ‚Šใ€ๆ—ฅๆœฌ่ชžใงๅ›ž็ญ”ใ—ใฆใใ ใ•ใ„ใ€‚",
        tools=[file_search_tool],
        tool_resources=resources,
        temperature=0,
    )
def create_thread(openai_client=None):
    """Create a new, empty Assistants API conversation thread.

    Args:
        openai_client: Client to create the thread with. Defaults to the
            module-level ``client`` so existing zero-argument callers keep
            working; passing one explicitly matches how the other helpers
            (``create_assistant``, ``get_chatbot_response``) receive theirs.

    Returns:
        The thread object returned by ``beta.threads.create()``.
    """
    # The global is only looked up when no client is supplied.
    api = client if openai_client is None else openai_client
    return api.beta.threads.create()
def clear_thread(state):
    """Start a fresh assistant thread for this session and clear the chat UI.

    Rebinding the ``state`` parameter would be lost: the Clear handler's
    outputs are only the chatbot and textbox. So the session-state dict is
    updated in place, and later messages then go to the new thread instead
    of continuing the old conversation server-side.

    NOTE(review): this relies on Gradio passing the stored session-state
    object by reference. Confirm this for the Gradio version in use. Otherwise,
    return the new state and add ``state`` to the click outputs.

    Args:
        state: Per-session dict holding ``"thread_id"``.

    Returns:
        ``([], "")`` — an empty chat history and an empty textbox value.
    """
    fresh_session = initialize_session()
    if isinstance(state, dict):
        state.update(fresh_session)
    return [], ""
def get_annotations(msg, openai_client=None):
    """Return the unique file IDs cited by a message, printing each cited file.

    Scans every text content part of the message and only considers
    ``file_citation`` annotations. Other annotation kinds (e.g. ``file_path``)
    and non-text parts are skipped instead of raising ``AttributeError``.

    Args:
        msg: An Assistants API message object.
        openai_client: Client used to look up file metadata. Defaults to the
            module-level ``client`` so existing one-argument callers keep working.

    Returns:
        Cited file IDs in first-seen order, without duplicates.
    """
    api = client if openai_client is None else openai_client
    file_ids = []
    for part in msg.content or []:
        text = getattr(part, "text", None)
        if text is None:
            continue  # non-text content (e.g. images) carries no annotations
        for annotation in text.annotations or []:
            citation = getattr(annotation, "file_citation", None)
            if citation is None:
                continue
            file_id = citation.file_id
            if file_id in file_ids:
                continue
            print("file_id", file_id)
            # The metadata lookup is diagnostic only; a failed lookup must not
            # drop the citation itself.
            try:
                cited_file = api.files.retrieve(file_id)
            except Exception as e:
                print(e)
            else:
                print("filename", cited_file.filename)
            file_ids.append(file_id)
    return file_ids
def get_chatbot_response(client, thread_id, assistant_id, message, *, poll_interval=1.0):
    """Post a user message to a thread, run the assistant, and return its reply.

    Args:
        client: Azure OpenAI client exposing the ``beta.threads`` API.
        thread_id: ID of the conversation thread to append to.
        assistant_id: ID of the assistant that should answer.
        message: The user's message text.
        poll_interval: Seconds to sleep between run-status polls.

    Returns:
        The text of the assistant message produced by *this* run, with all
        text parts joined by newlines. Returns a fallback string when the run
        does not complete or yields no text.
    """
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message,  # the user's plain-text message
    )
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
    )
    # Poll until the run leaves its non-terminal states.
    while run.status in ("queued", "in_progress", "cancelling"):
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )

    if run.status == "completed":
        # Newest first. Pick the reply produced by this run rather than
        # whatever message happens to be newest in the thread.
        for msg in client.beta.threads.messages.list(thread_id=thread_id, order="desc"):
            if msg.role != "assistant" or getattr(msg, "run_id", run.id) != run.id:
                continue
            texts = [
                part.text.value
                for part in msg.content
                if getattr(part, "text", None) is not None
            ]
            if texts:
                return "\n".join(texts)
    elif run.status == "requires_action":
        # Nothing here submits tool outputs. Cancel the run so the thread is not
        # left with an active run that blocks the next user message.
        try:
            client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run.id)
        except Exception as e:
            print(e)
    return "Unable to retrieve a response."  # Fallback response
def chatbot_response(history, message):
    """Answer *message* on the module-level ``thread`` and record both turns.

    NOTE(review): this reads a module-level ``thread`` that is not defined
    anywhere visible in this file, and the UI wires ``respond`` rather than this
    function. Calling it as-is would likely raise NameError — confirm or remove.

    Args:
        history: Chat history list (role/content dicts); mutated in place.
        message: The user's message text.

    Returns:
        ``(history, "")`` — the updated history and an empty textbox value.
    """
    reply = get_chatbot_response(client, thread.id, assistant.id, message)
    history.extend(
        [
            {"role": "user", "content": message},
            {"role": "assistant", "content": reply},
        ]
    )
    return history, ""
# Load environment variables, then build the shared client and assistant at import time.
load_environment()
client = initialize_openai_client()
vector_store_id = os.getenv("AZURE_OPENAI_VECTOR_STORE_ID")
if not vector_store_id:
    # Fail fast with a clear message instead of sending vector_store_ids=[None]
    # to the API and surfacing an opaque request error.
    raise RuntimeError(
        "AZURE_OPENAI_VECTOR_STORE_ID is not set; it is required for the file_search tool."
    )
# NOTE: a new assistant is created on every start; nothing visible here deletes old ones.
assistant = create_assistant(client, vector_store_id)
def respond(message, chat_history, state):
    """Send the user's message to the assistant and append both turns to the chat.

    Args:
        message: Text from the input textbox.
        chat_history: Chat history as a list of ``{"role", "content"}`` dicts
            (Gradio "messages" format); ``None`` is treated as empty.
        state: Per-session dict holding the Assistants API ``"thread_id"``.

    Returns:
        ``("", chat_history)`` — clears the textbox and returns the updated history.
    """
    if chat_history is None:
        chat_history = []
    # Ignore blank submissions. There is nothing to answer, and sending one
    # would still cost an API round trip (and may be rejected by the API).
    if not message or not message.strip():
        return "", chat_history
    thread_id = state["thread_id"]
    bot_message = get_chatbot_response(client, thread_id, assistant.id, message)
    chat_history.append({"role": "user", "content": message})
    chat_history.append({"role": "assistant", "content": bot_message})
    return "", chat_history
def initialize_session():
    """Build fresh per-session state: a dict holding a newly created thread's ID.

    Used as the ``gr.State`` initializer, so each session talks to its own
    Assistants API thread.
    """
    return {"thread_id": create_thread().id}
# Gradio UI: a chat box answered by the Azure OpenAI assistant, one thread per session.
with gr.Blocks() as demo:
    gr.Markdown("""
# Azure OpenAI Assistants API x Gradio x Zenn
This is a Gradio demo of Retrieval-Augmented Generation (RAG) using the Azure OpenAI Assistants API, applied to [Zenn articles](https://zenn.dev/nakamura196).
""")
    # "messages" type: history is a list of {"role", "content"} dicts, matching respond().
    chatbot = gr.Chatbot(type="messages")
    # NOTE(review): the placeholder text looks mis-encoded in this copy of the file;
    # presumably Japanese for "Enter your message here..." — verify the encoding.
    msg = gr.Textbox(placeholder="ใ“ใ“ใซใƒกใƒƒใ‚ปใƒผใ‚ธใ‚’ๅ…ฅๅŠ›ใ—ใฆใใ ใ•ใ„...")
    # Initialize per-session state; initialize_session builds each session's initial value.
    state = gr.State(initialize_session)
    clear = gr.Button("Clear")
    # Enter in the textbox -> respond(); its ("", history) return clears the box and updates the chat.
    msg.submit(respond, [msg, chatbot, state], [msg, chatbot])
    # `state` is an input but not an output here, so clear_thread can only change the
    # session's thread by mutating the state dict in place.
    clear.click(clear_thread, inputs=[state], outputs=[chatbot, msg])

# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()