File size: 7,088 Bytes
32486dc
 
 
 
 
 
 
 
 
7d02ab7
 
 
 
32486dc
 
 
 
 
 
 
 
 
 
 
 
36c43cc
 
 
 
 
 
 
 
32486dc
 
36c43cc
32486dc
36c43cc
 
32486dc
36c43cc
 
32486dc
 
 
7d02ab7
 
 
32486dc
 
 
7d02ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32486dc
 
 
 
 
 
 
 
 
36c43cc
32486dc
 
 
 
 
 
 
36c43cc
32486dc
36c43cc
32486dc
 
 
36c43cc
32486dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d02ab7
 
32486dc
 
 
 
7d02ab7
 
 
 
 
 
 
 
32486dc
 
 
 
7d02ab7
 
 
 
 
 
 
32486dc
 
7d02ab7
 
32486dc
 
7d02ab7
32486dc
7d02ab7
32486dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import html
import os

import streamlit as st
import textgrad as tg
from stqdm import stqdm
from streamlit_elements import elements, mui, editor, dashboard

class CodeEditor:
    """Streamlit view that refines a code solution step by step with TextGrad.

    Everything that must survive a Streamlit rerun (problem text, prompts,
    the current code, the iteration counter and the per-iteration results)
    lives in ``st.session_state``; the instance itself only keeps the LLM
    engine and the outputs of the most recent optimisation step.
    """

    def __init__(self, data) -> None:
        """Set up the engine and seed the session state.

        Args:
            data: Mapping of defaults. The class reads the keys
                ``"default_initial_solution"``,
                ``"default_problem_description"``,
                ``"default_loss_system_prompt"`` and ``"instruction"``.
        """
        self.data = data
        # Seed only once so the original solution is retained across reruns.
        if 'original_code_content' not in st.session_state:
            st.session_state.original_code_content = self.data["default_initial_solution"]

        self.llm_engine = tg.get_engine("gpt-4o")
        print("="*50, "init", "="*50)

        # Outputs of the latest ``_run`` call.
        self.loss_value = ""
        self.gradients = []
        self.graph = None
        # Never read inside this class; kept so external readers of the old
        # attribute name keep working.
        self.code_gradients = ""

        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    @staticmethod
    def _code_block_html(code) -> str:
        """Return ``code`` wrapped in a ``<pre><code>`` block, HTML-escaped.

        Escaping ``&``, ``<`` and ``>`` keeps source such as ``a < b`` or
        ``List<int>`` from being parsed as markup when the result is
        rendered with ``unsafe_allow_html=True``.
        """
        escaped = html.escape(str(code), quote=False)
        return f'<pre><code class="language-python">{escaped}</code></pre>'

    def load_layout(self):
        """Draw the input widgets and mirror their values into session state."""
        state = st.session_state

        # Seed the editable fields the first time the page is rendered.
        if 'problem' not in state:
            state.problem = self.data["default_problem_description"]
        if 'loss_system_prompt' not in state:
            state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in state:
            state.instruction = self.data["instruction"]

        left, right = st.columns([1, 1])
        with left:
            state.problem = st.text_area("Problem description:", state.problem, height=300)
        with right:
            state.loss_system_prompt = st.text_area("Loss system prompt:", state.loss_system_prompt, height=150)
            state.instruction = st.text_area("Instruction for formatted LLM call:", state.instruction, height=100)

        # The code being optimised must also persist across reruns.
        if 'code_content' not in state:
            state.code_content = self.data["default_initial_solution"]

        # Two columns are created but only the first is populated.
        col1, col2 = st.columns(2)
        with col1:
            with elements("monaco_editors_widget_original"):
                st.markdown("**Initial solution:**")
                code = st.text_area("Edit your code here:", state.original_code_content, height=300)
                # Before the first optimisation step the editor is the source
                # of truth. Syncing unconditionally (rather than only when the
                # text differs from the original) means that reverting an
                # edit back to the initial text is honoured too. After the
                # first step the optimiser owns ``code_content``.
                if code is not None and state.iteration == 0:
                    state.code_content = code

    def _run(self):
        """Run one TextGrad step on the code stored in session state.

        Writes ``self.loss_value``, ``self.graph`` and ``self.gradients`` and
        stores the updated code back into ``st.session_state.code_content``.
        """
        # Code is the variable we want to optimise -- so requires_grad=True.
        code = tg.Variable(value=st.session_state.code_content,
                           requires_grad=True,
                           role_description="code instance to optimize")

        # The problem statement is fixed -- so requires_grad=False.
        problem = tg.Variable(st.session_state.problem,
                              requires_grad=False,
                              role_description="the coding problem")

        # Let TGD know to update code!
        optimizer = tg.TGD(parameters=[code])

        loss_system_prompt = tg.Variable(st.session_state.loss_system_prompt,
                                         requires_grad=False,
                                         role_description="system prompt to the loss function")

        # The doubled braces survive this first ``format`` call as literal
        # ``{problem}`` / ``{code}`` placeholders for the formatted LLM call.
        # NOTE(review): braces typed into the instruction itself are not
        # escaped; if the LLM call later applies ``str.format`` to this
        # template they would be treated as placeholders -- confirm.
        format_string = "{instruction}\nProblem: {{problem}}\nCurrent Code: {{code}}"
        format_string = format_string.format(instruction=st.session_state.instruction)

        formatted_llm_call = tg.autograd.FormattedLLMCall(engine=self.llm_engine,
                                                          format_string=format_string,
                                                          fields={"problem": None, "code": None},
                                                          system_prompt=loss_system_prompt)

        # Finally, the loss: an LLM evaluation of the code against the problem.
        loss = formatted_llm_call(
            inputs={"problem": problem, "code": code},
            response_role_description=f"evaluation of the {code.get_role_description()}",
        )
        self.loss_value = loss.value
        self.graph = loss.generate_graph()

        loss.backward()
        self.gradients = code.gradients
        optimizer.step()  # Let's update the code

        st.session_state.code_content = code.value

    def show_results(self):
        """Run one optimisation step and render every iteration in its own tab."""
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'gradients': self.gradients,
            'code_content': st.session_state.code_content,
        })

        tabs = st.tabs([f"Iteration {i+1}" for i in range(st.session_state.iteration)])

        # Include Highlight.js library and a theme CSS
        st.markdown("""
        <link rel="stylesheet"
            href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/styles/default.min.css">
        <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.5.0/highlight.min.js"></script>
        <script>hljs.highlightAll();</script>
        """, unsafe_allow_html=True)

        for tab, result in zip(tabs, st.session_state.results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")

                st.markdown("### Current solution")
                # The code is escaped before being embedded as raw HTML.
                st.markdown(self._code_block_html(result["code_content"]),
                            unsafe_allow_html=True)

                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("### Loss value")
                    st.markdown("**Loss value is based on previous code.**")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("### Code gradients")
                    for gradient in result['gradients']:
                        st.markdown(gradient.value)