"""
PROMPT> python -m src.assume.assumption_orchestrator
"""
import logging
from typing import Callable, Optional
from llama_index.core.llms.llm import LLM
from src.assume.make_assumptions import MakeAssumptions
from src.assume.distill_assumptions import DistillAssumptions
from src.format_json_for_use_in_query import format_json_for_use_in_query

logger = logging.getLogger(__name__)

class AssumptionOrchestrator:
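    """Run MakeAssumptions followed by DistillAssumptions for a query.

    Optional post-phase callbacks receive each phase's result as soon as it
    is available.
    """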
    def __init__(self):
        # Optional hooks invoked with each phase's result.
        self.phase1_post_callback: Optional[Callable[[MakeAssumptions], None]] = None
        self.phase2_post_callback: Optional[Callable[[DistillAssumptions], None]] = None
        # Populated by execute(); None until the corresponding phase has run.
        self.make_assumptions: Optional[MakeAssumptions] = None
        self.distill_assumptions: Optional[DistillAssumptions] = None

    def execute(self, llm: LLM, query: str) -> None:
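        """Run both phases against the query, invoking the post-phase callbacks when set."""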
        logger.info("Making assumptions...")

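        # Phase 1: generate assumptions for the query.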
        self.make_assumptions = MakeAssumptions.execute(llm, query)
        if self.phase1_post_callback:
            self.phase1_post_callback(self.make_assumptions)

        logger.info("Distilling assumptions...")

        assumptions_json_string = format_json_for_use_in_query(self.make_assumptions.assumptions)

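        # Phase 2: distill the phase-1 assumptions, passed along with the
        # original query as a JSON payload.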
        query2 = (
            f"{query}\n\n"
            f"assumption.json:\n{assumptions_json_string}"
        )
        self.distill_assumptions = DistillAssumptions.execute(llm, query2)
        if self.phase2_post_callback:
            self.phase2_post_callback(self.distill_assumptions)
    
if __name__ == "__main__":
    import logging
    from src.llm_factory import get_llm
    from src.plan.find_plan_prompt import find_plan_prompt
    import json

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler()
        ]
    )

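    # Demo run: look up a sample plan prompt by id and pick an LLM.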
    plan_prompt = find_plan_prompt("4dc34d55-0d0d-4e9d-92f4-23765f49dd29")

    llm = get_llm("ollama-llama3.1")
    # llm = get_llm("openrouter-paid-gemini-2.0-flash-001")
    # llm = get_llm("deepseek-chat")

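    # Callbacks that pretty-print each phase's result as JSON.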
    def phase1_post_callback(make_assumptions: MakeAssumptions) -> None:
        count = len(make_assumptions.assumptions)
        d = make_assumptions.to_dict(include_system_prompt=False, include_user_prompt=False)
        pretty = json.dumps(d, indent=2)
        print(f"MakeAssumptions: Made {count} assumptions:\n{pretty}")

    def phase2_post_callback(distill_assumptions: DistillAssumptions) -> None:
        d = distill_assumptions.to_dict(include_system_prompt=False, include_user_prompt=False)
        pretty = json.dumps(d, indent=2)
        print(f"DistillAssumptions:\n{pretty}")

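    # Wire up the orchestrator and run it end-to-end.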
    orchestrator = AssumptionOrchestrator()
    orchestrator.phase1_post_callback = phase1_post_callback
    orchestrator.phase2_post_callback = phase2_post_callback
    orchestrator.execute(llm, plan_prompt)