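# NOTE(review): the next three lines look like web-page capture residue rather than module
# content; they are kept as comments so the file parses -- confirm and remove.
# Spaces:
# Runtime error
# Runtime error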
from typing import Dict, Iterator, List, Optional

from phi.api.prompt import sync_prompt_registry_api, sync_prompt_template_api
from phi.api.schemas.prompt import (
    PromptRegistrySync,
    PromptTemplatesSync,
    PromptTemplateSync,
    PromptRegistrySchema,
    PromptTemplateSchema,
)
from phi.prompt.template import PromptTemplate
from phi.prompt.exceptions import PromptUpdateException, PromptNotFoundException
from phi.utils.log import logger


class PromptRegistry:
    """A named collection of prompt templates that can be synced with phidata.

    The registry keeps two dictionaries keyed by prompt id:

    * ``prompts`` -- templates passed to the constructor; ``update`` refuses
      to replace these.
    * ``all_prompts`` -- every template known to the registry: the initial
      ones, anything added or updated later, and templates returned by the
      remote registry sync.

    Besides the explicit methods, the registry supports ``registry[id]``
    (returns ``None`` for an unknown id), ``id in registry`` and iteration
    over prompt ids.
    """

    def __init__(self, name: str, prompts: Optional[List[PromptTemplate]] = None, sync: bool = True):
        """Create the registry and, when ``sync`` is true, sync it immediately.

        Args:
            name: Registry name; sent as ``registry_name`` on every sync call.
            prompts: Templates to register up front. Each must have an ``id``.
            sync: Whether the registry syncs with phidata, both now and on
                every later ``add``/``update``.

        Raises:
            ValueError: If ``name`` is ``None`` or a supplied template has no id.
        """
        if name is None:
            raise ValueError("PromptRegistry must have a name.")

        self.name: str = name

        # Prompts initialized with the registry
        # NOTE: These prompts cannot be updated
        self.prompts: Dict[str, PromptTemplate] = {}
        if prompts:
            for _prompt in prompts:
                if _prompt.id is None:
                    raise ValueError("PromptTemplate cannot be added to Registry without an id.")
                self.prompts[_prompt.id] = _prompt

        # All prompts in the registry, including those synced from phidata
        self.all_prompts: Dict[str, PromptTemplate] = {}
        self.all_prompts.update(self.prompts)

        # If the registry should sync with phidata
        self._sync = sync
        self._remote_registry: Optional[PromptRegistrySchema] = None
        self._remote_templates: Optional[Dict[str, PromptTemplateSchema]] = None

        # Sync the registry with phidata
        if self._sync:
            self.sync_registry()
        logger.debug(f"Initialized prompt registry: {name}")

    def get(self, id: str) -> Optional[PromptTemplate]:
        """Return the template stored under ``id``, or ``None`` if there is none."""
        logger.debug(f"Getting prompt: {id}")
        return self.all_prompts.get(id, None)

    def get_all(self) -> Dict[str, PromptTemplate]:
        """Return the registry's ``all_prompts`` dictionary (the live object, not a copy)."""
        return self.all_prompts

    def add(self, prompt: PromptTemplate):
        """Store ``prompt`` under its own id and sync it if syncing is enabled.

        Raises:
            ValueError: If ``prompt.id`` is ``None``.
        """
        prompt_id = prompt.id
        if prompt_id is None:
            raise ValueError("PromptTemplate cannot be added to Registry without an id.")

        # NOTE(review): unlike update(), this does not check self.prompts, so an
        # id supplied at construction can be replaced in all_prompts here --
        # confirm whether that is intended.
        self.all_prompts[prompt_id] = prompt
        if self._sync:
            self._sync_template(prompt_id, prompt)
        logger.debug(f"Added prompt: {prompt_id}")

    def update(self, id: str, prompt: PromptTemplate, upsert: bool = True):
        """Replace the template stored under ``id``, inserting it when ``upsert`` is true.

        Args:
            id: Prompt id to write to.
            prompt: The new template.
            upsert: When false, an unknown ``id`` is an error instead of an insert.

        Raises:
            PromptUpdateException: If ``id`` belongs to a prompt the registry
                was initialized with.
            PromptNotFoundException: If ``upsert`` is false and ``id`` is unknown.
        """
        # Check if the prompt exists in the initial registry and should not be updated
        if id in self.prompts:
            raise PromptUpdateException(f"Prompt Id: {id} cannot be updated as it is initialized with the registry.")

        # If upsert is False and the prompt is not found, raise an exception
        if not upsert and id not in self.all_prompts:
            raise PromptNotFoundException(f"Prompt Id: {id} not found in registry.")

        # Update or insert the prompt
        self.all_prompts[id] = prompt

        # Sync the template if sync is enabled
        if self._sync:
            self._sync_template(id, prompt)
        logger.debug(f"Updated prompt: {id}")

    def sync_registry(self):
        """Send the initial prompts to phidata and merge the returned templates.

        The registry and template schemas returned by the API are cached on the
        instance. Every returned template is validated into a ``PromptTemplate``
        and written into ``all_prompts`` under its key.
        """
        logger.debug(f"Syncing registry with phidata: {self.name}")
        self._remote_registry, self._remote_templates = sync_prompt_registry_api(
            registry=PromptRegistrySync(registry_name=self.name),
            templates=PromptTemplatesSync(
                templates={
                    k: PromptTemplateSync(template_id=k, template_data=v.model_dump(exclude_none=True))
                    for k, v in self.prompts.items()
                }
            ),
        )
        if self._remote_templates is not None:
            for k, v in self._remote_templates.items():
                self.all_prompts[k] = PromptTemplate.model_validate(v.template_data)
        logger.debug(f"Synced registry with phidata: {self.name}")

    def _sync_template(self, id: str, prompt: PromptTemplate):
        """Push one template to phidata unless the cached remote copy already matches it."""
        logger.debug(f"Syncing template: {id} with registry: {self.name}")

        # Serialize once; the same payload is used for the comparison and the upload.
        template_data = prompt.model_dump(exclude_none=True)

        # Determine if the template needs to be synced either because
        #   remote templates are not available, or
        #   template is not in remote templates, or
        #   the template_data has changed.
        needs_sync = (
            self._remote_templates is None
            or id not in self._remote_templates
            or self._remote_templates[id].template_data != template_data
        )
        if not needs_sync:
            return

        _prompt_template: Optional[PromptTemplateSchema] = sync_prompt_template_api(
            registry=PromptRegistrySync(registry_name=self.name),
            prompt_template=PromptTemplateSync(template_id=id, template_data=template_data),
        )
        # Only cache the remote copy when the API returned one.
        if _prompt_template is not None:
            if self._remote_templates is None:
                self._remote_templates = {}
            self._remote_templates[id] = _prompt_template

    def __getitem__(self, id) -> Optional[PromptTemplate]:
        """Return ``self.get(id)``; an unknown id yields ``None`` rather than ``KeyError``."""
        return self.get(id)

    def __contains__(self, id: object) -> bool:
        """Return whether ``id`` is a known prompt id.

        Defined explicitly because ``__getitem__`` never raises: without this,
        ``in`` would fall back to probing ``__getitem__`` with 0, 1, 2, ... and
        never terminate.
        """
        return id in self.all_prompts

    def __iter__(self) -> Iterator[str]:
        """Iterate over the prompt ids in ``all_prompts``.

        Defined explicitly for the same reason as ``__contains__``: the
        ``__getitem__`` fallback used for iteration would loop forever.
        """
        return iter(self.all_prompts)

    def __str__(self):
        return f"PromptRegistry: {self.name}"

    def __repr__(self):
        return f"PromptRegistry: {self.name}"