David Chu commited on
Commit
b868906
·
unverified ·
1 Parent(s): 8dd0d65

feat: add thought summary in the response

Browse files
Files changed (3) hide show
  1. app/agent.py +13 -1
  2. app/models.py +1 -0
  3. main.py +7 -5
app/agent.py CHANGED
@@ -15,6 +15,7 @@ CONFIG = types.GenerateContentConfig(
15
  literature.search_medical_literature,
16
  ],
17
  system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
 
18
  )
19
 
20
  SOURCE_TOOL_NAMES = {
@@ -48,6 +49,13 @@ def hydrate_sources(
48
 
49
 
50
  def validate_response(response: types.GenerateContentResponse) -> models.Statements:
 
 
 
 
 
 
 
51
  text = (response.text or "").strip()
52
 
53
  # Extract content inside the first markdown code block (``` or ```json)
@@ -57,8 +65,12 @@ def validate_response(response: types.GenerateContentResponse) -> models.Stateme
57
 
58
  try:
59
  statements = models.Statements.model_validate_json(f'{{"statements":{text}}}')
 
60
  except ValidationError:
61
- statements = models.Statements(statements=[models.Statement(text=text)])
 
 
 
62
 
63
  statements = hydrate_sources(
64
  statements, response.automatic_function_calling_history or []
 
15
  literature.search_medical_literature,
16
  ],
17
  system_instruction=(Path(__file__).parent / "system_instruction.txt").read_text(),
18
+ thinking_config=types.ThinkingConfig(include_thoughts=True),
19
  )
20
 
21
  SOURCE_TOOL_NAMES = {
 
49
 
50
 
51
  def validate_response(response: types.GenerateContentResponse) -> models.Statements:
52
+ thoughts = []
53
+
54
+ for part in response.candidates[0].content.parts: # type: ignore
55
+ if part.thought:
56
+ thoughts.append(part.text)
57
+
58
+ thoughts = " ".join(thoughts)
59
  text = (response.text or "").strip()
60
 
61
  # Extract content inside the first markdown code block (``` or ```json)
 
65
 
66
  try:
67
  statements = models.Statements.model_validate_json(f'{{"statements":{text}}}')
68
+ statements.thoughts = thoughts
69
  except ValidationError:
70
+ statements = models.Statements(
71
+ statements=[models.Statement(text=text)],
72
+ thoughts=thoughts,
73
+ )
74
 
75
  statements = hydrate_sources(
76
  statements, response.automatic_function_calling_history or []
app/models.py CHANGED
@@ -25,3 +25,4 @@ class Statement(BaseModel):
25
 
26
  class Statements(BaseModel):
27
  statements: list[Statement]
 
 
25
 
26
  class Statements(BaseModel):
27
  statements: list[Statement]
28
+ thoughts: str | None = None
main.py CHANGED
@@ -28,7 +28,9 @@ def format_output(statements: models.Statements) -> tuple[str, str]:
28
 
29
  answer = " ".join(sentences)
30
  footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
31
- return answer, footnotes
 
 
32
 
33
 
34
  def main():
@@ -38,14 +40,14 @@ def main():
38
  with st.form("search", border=False):
39
  query = st.text_input("Your medical question")
40
  submit = st.form_submit_button("Ask")
41
- response = st.empty()
42
 
43
  if submit:
44
  with st.spinner("Thinking...", show_time=True):
45
  output = agent.respond(gemini_client, query)
46
-
47
- answer, footnotes = format_output(output)
48
- response.markdown(f"{answer}\n\n{footnotes}")
 
49
 
50
 
51
  if __name__ == "__main__":
 
28
 
29
  answer = " ".join(sentences)
30
  footnotes = "\n".join(f"[^{id}]: {citation}" for citation, id in citations.items())
31
+ thought = statements.thoughts or ""
32
+
33
+ return f"{answer}\n\n{footnotes}", thought
34
 
35
 
36
  def main():
 
40
  with st.form("search", border=False):
41
  query = st.text_input("Your medical question")
42
  submit = st.form_submit_button("Ask")
 
43
 
44
  if submit:
45
  with st.spinner("Thinking...", show_time=True):
46
  output = agent.respond(gemini_client, query)
47
+ answer, thoughts = format_output(output)
48
+ with st.expander("Thinking Process"):
49
+ st.markdown(thoughts)
50
+ st.markdown(answer)
51
 
52
 
53
  if __name__ == "__main__":