Spaces:
Runtime error
Runtime error
from typing import Any, Optional, Dict, List, Union, Callable | |
from pydantic import BaseModel, ConfigDict | |
from phi.assistant.openai.run import Run | |
from phi.assistant.openai.message import Message | |
from phi.assistant.openai.assistant import OpenAIAssistant | |
from phi.assistant.openai.exceptions import ThreadIdNotSet | |
from phi.utils.log import logger | |
try: | |
from openai import OpenAI | |
from openai.types.beta.assistant import Assistant as OpenAIAssistantType | |
from openai.types.beta.thread import Thread as OpenAIThread | |
from openai.types.beta.thread_deleted import ThreadDeleted as OpenAIThreadDeleted | |
except ImportError: | |
logger.error("`openai` not installed") | |
raise | |
class Thread(BaseModel):
    """Local model of an OpenAI Assistants API thread.

    Wraps the ``client.beta.threads`` endpoints (create, retrieve, update,
    delete and ``messages.list``) and keeps the most recent thread object
    returned by the API in ``openai_thread``.
    """

    # -*- Thread settings
    # Thread id which can be referenced in API endpoints.
    id: Optional[str] = None
    # The object type, populated by the API. Always thread.
    object: Optional[str] = None
    # OpenAIAssistant used for this thread
    assistant: Optional[OpenAIAssistant] = None
    # The ID of the assistant for this thread.
    assistant_id: Optional[str] = None
    # Set of 16 key-value pairs that can be attached to an object.
    # This can be useful for storing additional information about the object in a structured format.
    # Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
    metadata: Optional[Dict[str, Any]] = None
    # True if this thread is active
    is_active: bool = True
    # The Unix timestamp (in seconds) for when the thread was created.
    created_at: Optional[int] = None

    # Optional pre-configured client; when unset, `client` builds a default OpenAI().
    openai: Optional[OpenAI] = None
    # Most recent thread object returned by create/retrieve/update.
    openai_thread: Optional[OpenAIThread] = None
    openai_assistant: Optional[OpenAIAssistantType] = None

    # Needed so non-pydantic types such as the OpenAI client can be declared as fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def client(self) -> OpenAI:
        """Return the configured OpenAI client, or a new default ``OpenAI()`` if none is set."""
        # Must be a property: every API call in this class is written as `self.client.beta...`.
        return self.openai or OpenAI()

    @property
    def messages(self) -> List[Message]:
        """Return the messages in this thread, or ``[]`` if the thread has no id yet."""
        # Must be a property: `print_response` iterates `self.messages` directly.
        try:
            return self.get_messages()
        except ThreadIdNotSet:
            return []

    def load_from_openai(self, openai_thread: OpenAIThread) -> None:
        """Copy ``id``, ``object`` and ``created_at`` from an API thread and cache the object.

        ``metadata`` is intentionally left untouched so locally set values survive.
        """
        self.id = openai_thread.id
        self.object = openai_thread.object
        self.created_at = openai_thread.created_at
        self.openai_thread = openai_thread

    def create(self, messages: Optional[List[Union[Message, Dict]]] = None) -> "Thread":
        """Create this thread through the API.

        Args:
            messages: Optional initial messages. ``Message`` instances are sent via
                ``to_dict()``; dicts are sent as-is.

        Returns:
            ``self``, populated from the created thread.
        """
        request_body: Dict[str, Any] = {}
        if messages is not None:
            request_body["messages"] = [
                _message.to_dict() if isinstance(_message, Message) else _message for _message in messages
            ]
        if self.metadata is not None:
            request_body["metadata"] = self.metadata

        self.openai_thread = self.client.beta.threads.create(**request_body)
        self.load_from_openai(self.openai_thread)
        logger.debug(f"Thread created: {self.id}")
        return self

    def get_id(self) -> Optional[str]:
        """Return ``id``, falling back to the cached ``openai_thread``'s id, else ``None``."""
        # The parentheses matter: a bare `a or b if c else None` parses as
        # `(a or b) if c else None`, which discarded a set `self.id` whenever
        # `openai_thread` had not been loaded yet.
        return self.id or (self.openai_thread.id if self.openai_thread is not None else None)

    def get_from_openai(self) -> OpenAIThread:
        """Retrieve this thread from the API and refresh local fields from it.

        Raises:
            ThreadIdNotSet: if no thread id is known locally.
        """
        _thread_id = self.get_id()
        if _thread_id is None:
            raise ThreadIdNotSet("Thread.id not set")

        self.openai_thread = self.client.beta.threads.retrieve(
            thread_id=_thread_id,
        )
        self.load_from_openai(self.openai_thread)
        return self.openai_thread

    def get(self, use_cache: bool = True) -> "Thread":
        """Return ``self``, retrieving from the API unless a cached thread exists and ``use_cache``."""
        if self.openai_thread is not None and use_cache:
            return self

        self.get_from_openai()
        return self

    def get_or_create(self, use_cache: bool = True, messages: Optional[List[Union[Message, Dict]]] = None) -> "Thread":
        """Fetch the thread if an id is known; otherwise create it with ``messages``."""
        try:
            return self.get(use_cache=use_cache)
        except ThreadIdNotSet:
            return self.create(messages=messages)

    def update(self) -> "Thread":
        """Push local ``metadata`` (if set) to the API and refresh from the response.

        Raises:
            ThreadIdNotSet: if no thread id is known locally.
        """
        try:
            thread_to_update = self.get_from_openai()
        except ThreadIdNotSet:
            logger.warning("Thread not available")
            raise

        request_body: Dict[str, Any] = {}
        if self.metadata is not None:
            request_body["metadata"] = self.metadata

        self.openai_thread = self.client.beta.threads.update(
            thread_id=thread_to_update.id,
            **request_body,
        )
        self.load_from_openai(self.openai_thread)
        logger.debug(f"Thread updated: {self.id}")
        return self

    def delete(self) -> OpenAIThreadDeleted:
        """Delete this thread through the API and return the API's deletion status.

        Raises:
            ThreadIdNotSet: if no thread id is known locally.
        """
        try:
            thread_to_delete = self.get_from_openai()
        except ThreadIdNotSet:
            logger.warning("Thread not available")
            raise

        deletion_status = self.client.beta.threads.delete(
            thread_id=thread_to_delete.id,
        )
        logger.debug(f"Thread deleted: {self.id}")
        return deletion_status

    def add_message(self, message: Union[Message, Dict]) -> None:
        """Create one message on this thread; dicts are converted with ``Message(**message)``."""
        try:
            message = message if isinstance(message, Message) else Message(**message)
        except Exception as e:
            logger.error(f"Error creating Message: {e}")
            raise

        # Use get_id() so this agrees with the existing-thread check in `add`.
        message.thread_id = self.get_id()
        message.create()

    def add(self, messages: List[Union[Message, Dict]]) -> None:
        """Add messages to an existing thread, or create the thread with them if it has no id."""
        existing_thread = self.get_id() is not None
        if existing_thread:
            for message in messages:
                self.add_message(message=message)
        else:
            self.create(messages=messages)

    def run(
        self,
        message: Optional[Union[str, Message]] = None,
        assistant: Optional[OpenAIAssistant] = None,
        assistant_id: Optional[str] = None,
        run: Optional[Run] = None,
        wait: bool = True,
        callback: Optional[Callable] = None,
    ) -> Run:
        """Optionally add a message, then start a run on this thread.

        Args:
            message: Message to add first; a ``str`` is wrapped as a user ``Message``.
            assistant: Assistant to run; defaults to ``self.assistant``.
            assistant_id: Assistant id; defaults to ``self.assistant_id``.
            run: ``Run`` to execute; a fresh ``Run()`` is used when omitted.
            wait: Forwarded to ``Run.run``.
            callback: Forwarded to ``Run.run``.

        Returns:
            Whatever ``Run.run`` returns.

        Raises:
            ThreadIdNotSet: if the thread id is still unknown after adding the message.
        """
        if message is not None:
            if isinstance(message, str):
                message = Message(role="user", content=message)
            self.add(messages=[message])

        try:
            _thread_id = self.get_id()
            if _thread_id is None:
                _thread_id = self.get_from_openai().id
        except ThreadIdNotSet:
            logger.error("Thread not available")
            raise

        _assistant = assistant or self.assistant
        _assistant_id = assistant_id or self.assistant_id
        _run = run or Run()
        return _run.run(
            thread_id=_thread_id, assistant=_assistant, assistant_id=_assistant_id, wait=wait, callback=callback
        )

    def get_messages(self) -> List[Message]:
        """List this thread's messages from the API, converted via ``Message.from_openai``.

        Raises:
            ThreadIdNotSet: if no thread id is known locally.
        """
        try:
            _thread_id = self.get_id()
            if _thread_id is None:
                _thread_id = self.get_from_openai().id
        except ThreadIdNotSet:
            logger.warning("Thread not available")
            raise

        thread_messages = self.client.beta.threads.messages.list(
            thread_id=_thread_id,
        )
        return [Message.from_openai(message=message) for message in thread_messages]

    def to_dict(self) -> Dict[str, Any]:
        """Dump the id, object and metadata fields, omitting ``None`` values."""
        # NOTE(review): "messages" is a property, not a model field, so model_dump
        # presumably never emits it — confirm whether it should be included.
        return self.model_dump(exclude_none=True, include={"id", "object", "messages", "metadata"})

    def pprint(self):
        """Pretty print using rich"""
        from rich.pretty import pprint

        pprint(self.to_dict())

    def print_messages(self) -> None:
        """Print the thread's messages as a rich table, one row per message."""
        from rich.table import Table
        from rich.box import ROUNDED
        from rich.markdown import Markdown
        from phi.cli.console import console

        # Get the messages from the thread
        messages = self.get_messages()

        # Print the response
        table = Table(
            box=ROUNDED,
            border_style="blue",
            expand=True,
        )
        # Columns are declared once; previously each user message added two new
        # columns, which broke the layout for threads with more than one turn.
        table.add_column("Role")
        table.add_column("Message")
        # NOTE(review): [::-1] assumes the API lists newest first — confirm.
        for m in messages[::-1]:
            if m.role == "user":
                table.add_row("User", m.get_content_with_files())
            elif m.role == "assistant":
                table.add_row("OpenAIAssistant", Markdown(m.get_content_with_files()))
                table.add_section()
            else:
                table.add_row(m.role, Markdown(m.get_content_with_files()))
                table.add_section()
        console.print(table)

    def print_response(
        self, message: str, assistant: OpenAIAssistant, current_message_only: bool = False, markdown: bool = False
    ) -> None:
        """Run ``message`` against ``assistant`` behind a spinner, then pretty-print the result.

        Args:
            message: User message to send.
            assistant: Assistant to run.
            current_message_only: If True, print only the assistant messages that follow
                the user message whose text equals ``message``; otherwise print all messages.
            markdown: Forwarded to ``Message.pprint``.
        """
        from rich.progress import Progress, SpinnerColumn, TextColumn

        with Progress(SpinnerColumn(spinner_name="dots"), TextColumn("{task.description}"), transient=True) as progress:
            progress.add_task("Working...")
            self.run(
                message=message,
                assistant=assistant,
                wait=True,
            )

        if current_message_only:
            # Walk newest-first (assumed API order) collecting replies until the
            # user message that triggered them is reached.
            response_messages = []
            for m in self.messages:
                if m.role == "assistant":
                    response_messages.append(m)
                elif m.role == "user" and m.get_content_text() == message:
                    break

            total_messages = len(response_messages)
            for idx, response_message in enumerate(response_messages[::-1], start=1):
                response_message.pprint(
                    title=f"[bold] :robot: OpenAIAssistant ({idx}/{total_messages}) [/bold]", markdown=markdown
                )
        else:
            for m in self.messages[::-1]:
                m.pprint(markdown=markdown)

    def __str__(self) -> str:
        """Return ``to_dict()`` as indented JSON."""
        import json

        return json.dumps(self.to_dict(), indent=4)