File size: 7,088 Bytes
32486dc 7d02ab7 32486dc 36c43cc 32486dc 36c43cc 32486dc 36c43cc 32486dc 36c43cc 32486dc 7d02ab7 32486dc 7d02ab7 32486dc 36c43cc 32486dc 36c43cc 32486dc 36c43cc 32486dc 36c43cc 32486dc 7d02ab7 32486dc 7d02ab7 32486dc 7d02ab7 32486dc 7d02ab7 32486dc 7d02ab7 32486dc 7d02ab7 32486dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import streamlit as st
from streamlit_elements import elements, mui, editor, dashboard
from stqdm import stqdm
import textgrad as tg
import os
class CodeEditor:
    """Streamlit UI for iteratively optimizing a code solution with TextGrad.

    ``data`` supplies the default texts shown in the UI; the keys read here are
    ``default_initial_solution``, ``default_problem_description``,
    ``default_loss_system_prompt`` and ``instruction``. User-editable values and
    the optimization history live in ``st.session_state`` so they survive
    Streamlit reruns.
    """

    def __init__(self, data, engine_name: str = "gpt-4o") -> None:
        """Store ``data``, seed session state and configure the TextGrad engines.

        Args:
            data: Mapping holding the default texts (see the class docstring).
            engine_name: Name passed to ``tg.get_engine``; the resulting engine
                is used for the loss call and set as the backward engine.
        """
        self.data = data
        # Seed only once so the original content is retained across reruns.
        if 'original_code_content' not in st.session_state:
            st.session_state.original_code_content = self.data["default_initial_solution"]
        self.llm_engine = tg.get_engine(engine_name)
        print("="*50, "init", "="*50)
        # Outputs of the most recent `_run`, read by `show_results`.
        self.loss_value = ""
        self.gradients = []
        self.graph = None
        # NOTE(review): not read anywhere in this file; kept in case external code uses it.
        self.code_gradients = ""
        if 'iteration' not in st.session_state:
            st.session_state.iteration = 0
        if 'results' not in st.session_state:
            st.session_state.results = []
        tg.set_backward_engine(self.llm_engine, override=True)

    def load_layout(self):
        """Render the editable prompt fields and the initial-solution editor.

        The three prompt text areas are bound to session-state keys, so their
        contents persist across reruns. The code editor only updates
        ``st.session_state.code_content`` before the first optimization
        iteration; afterwards ``code_content`` holds the optimizer's output.
        """
        # Seed session state with defaults on the first render only.
        if 'problem' not in st.session_state:
            st.session_state.problem = self.data["default_problem_description"]
        if 'loss_system_prompt' not in st.session_state:
            st.session_state.loss_system_prompt = self.data["default_loss_system_prompt"]
        if 'instruction' not in st.session_state:
            st.session_state.instruction = self.data["instruction"]

        col1, col2 = st.columns([1, 1])
        # Bind each widget to its session-state key instead of passing the
        # current value as the default. A changing default gives the widget a
        # new identity on the next rerun, which silently discards the user's
        # next edit.
        with col1:
            st.text_area("Problem description:", key="problem", height=300)
        with col2:
            st.text_area("Loss system prompt:", key="loss_system_prompt", height=150)
            st.text_area("Instruction for formatted LLM call:", key="instruction", height=100)

        # The working copy of the code must persist across reruns as well.
        if 'code_content' not in st.session_state:
            st.session_state.code_content = self.data["default_initial_solution"]

        def update_code_content(value):
            # Edits are only accepted before the first optimization step.
            if st.session_state.iteration == 0:
                st.session_state.code_content = value

        col1, _ = st.columns(2)
        with col1:
            # NOTE(review): this streamlit_elements frame originally hosted a
            # (now removed) Monaco editor; it is kept to preserve the layout.
            with elements("monaco_editors_widget_original"):
                st.markdown("**Initial solution:**")
                code = st.text_area("Edit your code here:", st.session_state.original_code_content, height=300)
                # Compare against the working copy, not the original, so that
                # reverting an edit back to the original text is also recorded.
                if code is not None and code != st.session_state.code_content:
                    update_code_content(code)

    @staticmethod
    def _build_format_string(instruction):
        """Return the FormattedLLMCall template for ``instruction``.

        Braces in ``instruction`` are doubled so that the later
        ``str.format(problem=..., code=...)`` treats them as literal text.
        Only ``{problem}`` and ``{code}`` remain as placeholders.
        """
        escaped = instruction.replace("{", "{{").replace("}", "}}")
        return f"{escaped}\nProblem: {{problem}}\nCurrent Code: {{code}}"

    def _run(self):
        """Run one TextGrad optimization step on ``st.session_state.code_content``.

        The current code is evaluated with a formatted LLM call (the "loss").
        Textual feedback is back-propagated to the code variable, one TGD step
        is applied, and the updated code is written back to
        ``st.session_state.code_content``. The loss text, the feedback
        variables and the computation graph of this step are stored on
        ``self.loss_value``, ``self.gradients`` and ``self.graph``.
        """
        # The code is the variable being optimized, so it requires a gradient.
        code = tg.Variable(value=st.session_state.code_content,
                           requires_grad=True,
                           role_description="code instance to optimize")
        # The problem statement is context only and is not optimized.
        problem = tg.Variable(st.session_state.problem,
                              requires_grad=False,
                              role_description="the coding problem")
        optimizer = tg.TGD(parameters=[code])

        loss_system_prompt = tg.Variable(st.session_state.loss_system_prompt,
                                         requires_grad=False,
                                         role_description="system prompt to the loss function")
        formatted_llm_call = tg.autograd.FormattedLLMCall(
            engine=self.llm_engine,
            format_string=self._build_format_string(st.session_state.instruction),
            fields={"problem": None, "code": None},
            system_prompt=loss_system_prompt)

        loss = formatted_llm_call(
            inputs={"problem": problem, "code": code},
            response_role_description=f"evaluation of the {code.get_role_description()}")
        self.loss_value = loss.value
        self.graph = loss.generate_graph()
        loss.backward()
        # Copy into a list so the stored history does not alias the variable's container.
        self.gradients = list(code.gradients)
        optimizer.step()
        st.session_state.code_content = code.value

    def show_results(self):
        """Run one optimization step, record it, and render every iteration so far.

        Each iteration gets its own tab showing three things: the code after
        that step, the loss text (computed on the code *before* the step), and
        the feedback variables that drove the update.
        """
        self._run()
        st.session_state.iteration += 1
        st.session_state.results.append({
            'iteration': st.session_state.iteration,
            'loss_value': self.loss_value,
            'gradients': self.gradients,
            'code_content': st.session_state.code_content,
        })

        results = st.session_state.results
        # One tab per stored result, so tabs and results cannot get out of sync.
        tabs = st.tabs([f"Iteration {i+1}" for i in range(len(results))])
        for tab, result in zip(tabs, results):
            with tab:
                st.markdown(f"Current iteration: **{result['iteration']}**")
                st.markdown("### Current solution")
                # st.code escapes and highlights the text itself. Interpolating
                # LLM output into raw HTML broke on '<' or '</code>', and
                # <script> tags passed to st.markdown are not executed.
                st.code(result["code_content"], language="python")
                col1, col2 = st.columns([1, 1])
                with col1:
                    st.markdown("### Loss value")
                    st.markdown("**Loss value is based on previous code.**")
                    st.markdown(result['loss_value'])
                with col2:
                    st.markdown("### Code gradients")
                    for gradient in result['gradients']:
                        st.markdown(gradient.value)