File size: 5,140 Bytes
105b369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List, Dict, Optional

from phi.api.prompt import sync_prompt_registry_api, sync_prompt_template_api
from phi.api.schemas.prompt import (
    PromptRegistrySync,
    PromptTemplatesSync,
    PromptTemplateSync,
    PromptRegistrySchema,
    PromptTemplateSchema,
)
from phi.prompt.template import PromptTemplate
from phi.prompt.exceptions import PromptUpdateException, PromptNotFoundException
from phi.utils.log import logger


class PromptRegistry:
    """A named collection of prompt templates, optionally synced with phidata.

    The registry keeps two maps keyed by template id:

    * ``prompts`` -- templates passed to the constructor. ``update()`` refuses
      to replace these.
    * ``all_prompts`` -- every template known to the registry: the initial
      ones, those added or updated later, and those returned by a registry
      sync.

    Besides the explicit methods, the registry supports ``registry[id]``
    (returns ``None`` for an unknown id), ``id in registry`` and iteration
    over template ids.
    """

    def __init__(self, name: str, prompts: Optional[List[PromptTemplate]] = None, sync: bool = True):
        """Create the registry and, if ``sync`` is true, sync it immediately.

        Args:
            name: Name of the registry; sent as ``registry_name`` on every sync.
            prompts: Templates to seed the registry with. Each must have an id.
            sync: When true, ``sync_registry()`` is called from the constructor
                and every ``add()``/``update()`` also syncs the single template.

        Raises:
            ValueError: If ``name`` is ``None`` or a seed template has no id.
        """
        if name is None:
            raise ValueError("PromptRegistry must have a name.")

        self.name: str = name
        # Prompts initialized with the registry
        # NOTE: update() refuses to replace these prompts
        self.prompts: Dict[str, PromptTemplate] = {}
        if prompts:
            for _prompt in prompts:
                if _prompt.id is None:
                    raise ValueError("PromptTemplate cannot be added to Registry without an id.")
                self.prompts[_prompt.id] = _prompt

        # All prompts in the registry, including those synced from phidata
        self.all_prompts: Dict[str, PromptTemplate] = {}
        self.all_prompts.update(self.prompts)

        # If the registry should sync with phidata
        self._sync = sync
        # Results of the last sync; both stay None until a sync returns data
        self._remote_registry: Optional[PromptRegistrySchema] = None
        self._remote_templates: Optional[Dict[str, PromptTemplateSchema]] = None
        # Sync the registry with phidata
        if self._sync:
            self.sync_registry()
        logger.debug(f"Initialized prompt registry: {name}")

    def get(self, id: str) -> Optional[PromptTemplate]:
        """Return the template stored under ``id``, or ``None`` if there is none."""
        logger.debug(f"Getting prompt: {id}")
        return self.all_prompts.get(id, None)

    def get_all(self) -> Dict[str, PromptTemplate]:
        """Return the registry's own id -> template map (not a copy)."""
        return self.all_prompts

    def add(self, prompt: PromptTemplate):
        """Store ``prompt`` under its own id and sync it if syncing is enabled.

        Raises:
            ValueError: If ``prompt.id`` is ``None``.
        """
        prompt_id = prompt.id
        if prompt_id is None:
            raise ValueError("PromptTemplate cannot be added to Registry without an id.")

        self.all_prompts[prompt_id] = prompt
        if self._sync:
            self._sync_template(prompt_id, prompt)
        logger.debug(f"Added prompt: {prompt_id}")

    def update(self, id: str, prompt: PromptTemplate, upsert: bool = True):
        """Replace (or, with ``upsert``, insert) the template stored under ``id``.

        Args:
            id: Key to store the template under.
            prompt: The new template.
            upsert: When false, an unknown ``id`` is an error instead of an insert.

        Raises:
            PromptUpdateException: If ``id`` belongs to a template the registry
                was initialized with.
            PromptNotFoundException: If ``upsert`` is false and ``id`` is unknown.
        """
        # Check if the prompt exists in the initial registry and should not be updated
        if id in self.prompts:
            raise PromptUpdateException(f"Prompt Id: {id} cannot be updated as it is initialized with the registry.")
        # If upsert is False and the prompt is not found, raise an exception
        if not upsert and id not in self.all_prompts:
            raise PromptNotFoundException(f"Prompt Id: {id} not found in registry.")
        # Update or insert the prompt
        self.all_prompts[id] = prompt
        # Sync the template if sync is enabled
        if self._sync:
            self._sync_template(id, prompt)
        logger.debug(f"Updated prompt: {id}")

    def sync_registry(self):
        """Send the initial templates to the sync API and merge what it returns.

        The call's two results are kept in ``_remote_registry`` and
        ``_remote_templates``. Every returned template is validated into a
        ``PromptTemplate`` and written into ``all_prompts`` under its key,
        replacing any entry already stored there.
        """
        logger.debug(f"Syncing registry with phidata: {self.name}")
        self._remote_registry, self._remote_templates = sync_prompt_registry_api(
            registry=PromptRegistrySync(registry_name=self.name),
            templates=PromptTemplatesSync(
                templates={
                    k: PromptTemplateSync(template_id=k, template_data=v.model_dump(exclude_none=True))
                    for k, v in self.prompts.items()
                }
            ),
        )
        if self._remote_templates is not None:
            for k, v in self._remote_templates.items():
                self.all_prompts[k] = PromptTemplate.model_validate(v.template_data)
        logger.debug(f"Synced registry with phidata: {self.name}")

    def _sync_template(self, id: str, prompt: PromptTemplate):
        """Push a single template to the sync API if it differs from the cached remote copy.

        A sync is needed when no remote templates are cached, when ``id`` is
        not among them, or when the cached ``template_data`` differs from the
        template's current dump. A non-``None`` API result is cached under
        ``id`` in ``_remote_templates``.
        """
        logger.debug(f"Syncing template: {id} with registry: {self.name}")

        # Dump once; the same payload is used for the comparison and the API call.
        template_data = prompt.model_dump(exclude_none=True)

        needs_sync = (
            self._remote_templates is None
            or id not in self._remote_templates
            or self._remote_templates[id].template_data != template_data
        )

        if needs_sync:
            _prompt_template: Optional[PromptTemplateSchema] = sync_prompt_template_api(
                registry=PromptRegistrySync(registry_name=self.name),
                prompt_template=PromptTemplateSync(template_id=id, template_data=template_data),
            )
            if _prompt_template is not None:
                if self._remote_templates is None:
                    self._remote_templates = {}
                self._remote_templates[id] = _prompt_template

    def __getitem__(self, id) -> Optional[PromptTemplate]:
        """Return ``self.get(id)``; an unknown id yields ``None``, not ``KeyError``."""
        return self.get(id)

    # NOTE: __contains__ and __iter__ must be defined explicitly. With only
    # __getitem__ present, Python falls back to the legacy sequence protocol for
    # `x in registry` and `for x in registry`: it calls __getitem__(0),
    # __getitem__(1), ... until IndexError is raised. get() never raises, so
    # that fallback would loop forever.
    def __contains__(self, id) -> bool:
        """Return whether a template is stored under ``id``."""
        return id in self.all_prompts

    def __iter__(self):
        """Iterate over the ids of all templates in the registry."""
        return iter(self.all_prompts)

    def __str__(self):
        return f"PromptRegistry: {self.name}"

    def __repr__(self):
        return f"PromptRegistry: {self.name}"