AmmarFahmy
adding all files
105b369
from typing import List, Dict, Optional
from phi.api.prompt import sync_prompt_registry_api, sync_prompt_template_api
from phi.api.schemas.prompt import (
PromptRegistrySync,
PromptTemplatesSync,
PromptTemplateSync,
PromptRegistrySchema,
PromptTemplateSchema,
)
from phi.prompt.template import PromptTemplate
from phi.prompt.exceptions import PromptUpdateException, PromptNotFoundException
from phi.utils.log import logger
class PromptRegistry:
def __init__(self, name: str, prompts: Optional[List[PromptTemplate]] = None, sync: bool = True):
if name is None:
raise ValueError("PromptRegistry must have a name.")
self.name: str = name
# Prompts initialized with the registry
# NOTE: These prompts cannot be updated
self.prompts: Dict[str, PromptTemplate] = {}
# Add prompts to prompts
if prompts:
for _prompt in prompts:
if _prompt.id is None:
raise ValueError("PromptTemplate cannot be added to Registry without an id.")
self.prompts[_prompt.id] = _prompt
# All prompts in the registry, including those synced from phidata
self.all_prompts: Dict[str, PromptTemplate] = {}
self.all_prompts.update(self.prompts)
# If the registry should sync with phidata
self._sync = sync
self._remote_registry: Optional[PromptRegistrySchema] = None
self._remote_templates: Optional[Dict[str, PromptTemplateSchema]] = None
# Sync the registry with phidata
if self._sync:
self.sync_registry()
logger.debug(f"Initialized prompt registry: {name}")
def get(self, id: str) -> Optional[PromptTemplate]:
logger.debug(f"Getting prompt: {id}")
return self.all_prompts.get(id, None)
def get_all(self) -> Dict[str, PromptTemplate]:
return self.all_prompts
def add(self, prompt: PromptTemplate):
prompt_id = prompt.id
if prompt_id is None:
raise ValueError("PromptTemplate cannot be added to Registry without an id.")
self.all_prompts[prompt_id] = prompt
if self._sync:
self._sync_template(prompt_id, prompt)
logger.debug(f"Added prompt: {prompt_id}")
def update(self, id: str, prompt: PromptTemplate, upsert: bool = True):
# Check if the prompt exists in the initial registry and should not be updated
if id in self.prompts:
raise PromptUpdateException(f"Prompt Id: {id} cannot be updated as it is initialized with the registry.")
# If upsert is False and the prompt is not found, raise an exception
if not upsert and id not in self.all_prompts:
raise PromptNotFoundException(f"Prompt Id: {id} not found in registry.")
# Update or insert the prompt
self.all_prompts[id] = prompt
# Sync the template if sync is enabled
if self._sync:
self._sync_template(id, prompt)
logger.debug(f"Updated prompt: {id}")
def sync_registry(self):
logger.debug(f"Syncing registry with phidata: {self.name}")
self._remote_registry, self._remote_templates = sync_prompt_registry_api(
registry=PromptRegistrySync(registry_name=self.name),
templates=PromptTemplatesSync(
templates={
k: PromptTemplateSync(template_id=k, template_data=v.model_dump(exclude_none=True))
for k, v in self.prompts.items()
}
),
)
if self._remote_templates is not None:
for k, v in self._remote_templates.items():
self.all_prompts[k] = PromptTemplate.model_validate(v.template_data)
logger.debug(f"Synced registry with phidata: {self.name}")
def _sync_template(self, id: str, prompt: PromptTemplate):
logger.debug(f"Syncing template: {id} with registry: {self.name}")
# Determine if the template needs to be synced either because
# remote templates are not available, or
# template is not in remote templates, or
# the template_data has changed.
needs_sync = (
self._remote_templates is None
or id not in self._remote_templates
or self._remote_templates[id].template_data != prompt.model_dump(exclude_none=True)
)
if needs_sync:
_prompt_template: Optional[PromptTemplateSchema] = sync_prompt_template_api(
registry=PromptRegistrySync(registry_name=self.name),
prompt_template=PromptTemplateSync(template_id=id, template_data=prompt.model_dump(exclude_none=True)),
)
if _prompt_template is not None:
if self._remote_templates is None:
self._remote_templates = {}
self._remote_templates[id] = _prompt_template
def __getitem__(self, id) -> Optional[PromptTemplate]:
return self.get(id)
def __str__(self):
return f"PromptRegistry: {self.name}"
def __repr__(self):
return f"PromptRegistry: {self.name}"