ykl7 committed
Commit 9a31c8f · 1 Parent(s): 4c814ad

add reasoner

Files changed (3)
  1. app.py +45 -5
  2. llm_reasoner.py +105 -0
  3. prompts.py +18 -0
app.py CHANGED
@@ -3,10 +3,25 @@ import random
 import time
 import hmac
 import os
+import json
+import logging
+from typing import Any
+from llm_reasoner import LLMReasoner
+from prompts import templates
+
+logger = logging.getLogger(__name__)
 
 st.header(" Scientific Claim Verification ")
 st.caption("Team UMBC-SBU-UT")
 
+def safe_parse_json(model_answer):
+    """Parse the model's JSON answer, returning None if it is malformed."""
+    try:
+        return json.loads(model_answer)
+    except json.JSONDecodeError as e:
+        logger.error("Failed to parse JSON: %s", e)
+        return None
+
 def check_password():
     """Returns `True` if the user had a correct password."""
 
@@ -114,7 +129,7 @@ for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-def retriever(query: str):
+def retriever(query: str, selected_retriever: str):
     """Simulate a 'retriever' step, searching for relevant information."""
     with st.chat_message("assistant"):
         placeholder = st.empty()
@@ -137,7 +152,7 @@ def retriever(query: str):
     # You could return retrieved info here.
     return message
 
-def reasoner(info: list[str]):
+def reasoner(query: str, documents: list[str], llm_client: Any):
    """Simulate a 'reasoner' step, thinking about how to answer."""
    with st.chat_message("assistant"):
        placeholder = st.empty()
@@ -150,6 +165,17 @@ def reasoner(info: list[str]):
        else:
            message = "Using o3-mini to quickly analyze the claim..."
 
+       if not documents:
+           prompt = templates["no_evidence"].format(claim=query)
+           llm_response = llm_client.run_inference(prompt)
+
+           message = message + '\n' + llm_response
+
+           answer_dict = safe_parse_json(llm_response)
+           decision = answer_dict.get("decision", "") if answer_dict else ""
+
+           message = message + '\n' + decision
+
        for chunk in message.split():
            text += chunk + " "
            time.sleep(0.05)
@@ -162,15 +188,34 @@ def reasoner(info: list[str]):
 # Accept user input
 if prompt := st.chat_input("Type here"):
     # Add user message to chat history
-    prompt= prompt + " \n"+ " \n"+ f"Retriever: {selected_retriever}, Reasoner: {selected_reasoner}"
+    prompt = prompt + " \n" + " \n" + f"Retriever: {selected_retriever}, Reasoner: {selected_reasoner}"
     st.session_state.messages.append({"role": "user", "content": prompt})
     # Display user message in chat message container
     with st.chat_message("user"):
         st.markdown(prompt)
 
+    # Configure the selected reasoner model
+    options = {}
+    options["max_tokens"] = 500
+    options["temperature"] = 0.0
+
+    if selected_reasoner == "Claude Sonnet":
+        api_key = os.getenv(st.session_state["claude_key"])
+        options["API_KEY"] = api_key
+        options["model_family"] = "Anthropic"
+        options["model_name"] = "claude-3-5-sonnet-20240620"
+
+    elif selected_reasoner == "GPT-4o":
+        api_key = os.getenv(st.session_state["openai_key"])
+        options["API_KEY"] = api_key
+        options["model_family"] = "OpenAI"
+        options["model_name"] = "gpt-4o-2024-05-13"
+
+    llm_client = LLMReasoner(options)
+
 
-    retrieved_documents=retriever(prompt)
-    reasoning = reasoner(retrieved_documents)
+    retrieved_documents = retriever(prompt, selected_retriever)
+    reasoning = reasoner(prompt, retrieved_documents, llm_client)
 
     # Display assistant response in chat message container
     with st.chat_message("assistant"):
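
Note the guard around the parsed answer: the model can emit malformed JSON, in which case `safe_parse_json` returns `None` and `reasoner` appends an empty decision instead of crashing. A standalone sketch of that behavior (the sample answer strings are made-up illustrations):

```python
import json
import logging

logger = logging.getLogger(__name__)

def safe_parse_json(model_answer):
    """Parse the model's JSON answer, returning None if it is malformed."""
    try:
        return json.loads(model_answer)
    except json.JSONDecodeError as e:
        logger.error("Failed to parse JSON: %s", e)
        return None

good = safe_parse_json('{"reasoning": "Plausible claim.", "decision": "SUPPORT"}')
bad = safe_parse_json("SUPPORT, because...")   # not JSON

print(good["decision"] if good else "")        # SUPPORT
print(bad)                                     # None
```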
llm_reasoner.py ADDED
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+__author__ = "Yash Kumar Lal, Github@ykl7"
+
+import os
+import openai
+from openai import OpenAI
+import anthropic
+import time
+import random
+from config import config
+
+random.seed(1234)
+
+class LLMReasoner():
+
+    def __init__(self, options):
+
+        if options["model_family"] == "OpenAI":
+            self.client = OpenAI(api_key=options["API_KEY"])
+        elif options["model_family"] == "Anthropic":
+            os.environ["ANTHROPIC_API_KEY"] = options["API_KEY"]
+            self.client = anthropic.Anthropic()
+
+        self.model_family = options["model_family"]
+        self.model_name = options["model_name"]
+        self.max_tokens = options["max_tokens"]
+        self.temp = options.get("temperature", 0.0)
+        self.top_p = options.get("top_p", 1.0)
+        self.frequency_penalty = options.get("frequency_penalty", 0.0)
+        self.presence_penalty = options.get("presence_penalty", 0.0)
+
+    def make_openai_chat_completions_api_call(self, messages):
+        """Call the Chat Completions API, retrying after transient failures."""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=self.temp,
+                #max_completion_tokens=self.max_tokens,
+                top_p=self.top_p,
+                frequency_penalty=self.frequency_penalty,
+                presence_penalty=self.presence_penalty
+            )
+            return self.parse_chat_completions_api_response(response)
+        except openai.APIConnectionError as e:
+            print("The server could not be reached")
+            print(e.__cause__)  # an underlying Exception, likely raised within httpx
+            time.sleep(60)
+            return self.make_openai_chat_completions_api_call(messages)
+        except openai.RateLimitError as e:
+            print("Rate limit error hit")
+            exit()
+        except openai.NotFoundError as e:
+            print("Model not found")
+            exit()
+        except openai.APIStatusError as e:
+            print("Another non-200-range status code was received")
+            print(e.status_code)
+            print(e.response)
+            time.sleep(60)
+            return self.make_openai_chat_completions_api_call(messages)
+
+    def parse_chat_completions_api_response(self, response):
+        # Extract the assistant's text from the first choice
+        choices = response.choices
+        main_response = choices[0].message
+        return main_response.content, response
+
+    def call_claude(self, claude_prompt=""):
+        """Call the Anthropic Messages API, retrying after transient failures."""
+        try:
+            message = self.client.messages.create(
+                model=self.model_name,
+                max_tokens=self.max_tokens,
+                temperature=self.temp,
+                system="",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": claude_prompt
+                            }
+                        ]
+                    }
+                ]
+            )
+        except Exception as e:
+            print(e)
+            time.sleep(30)
+            return self.call_claude(claude_prompt)
+        if message.content[0].type == "text":
+            return message.content[0].text, message
+        else:
+            return "Error", message
+
+    def run_inference(self, prompt=""):
+        # Route the prompt to the configured model family and return the text reply
+        if self.model_family == "OpenAI":
+            messages = [{"role": "user", "content": prompt}]
+            response_text, response = self.make_openai_chat_completions_api_call(messages)
+        elif self.model_family == "Anthropic":
+            response_text, response = self.call_claude(prompt)
+
+        return response_text
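
For reference, a minimal sketch of how app.py drives this class. The `options` keys mirror the ones set in the app.py diff above; reading the key from the `OPENAI_API_KEY` environment variable and the sample claim are illustrative assumptions (the app actually resolves keys via `st.session_state`):

```python
import os
from llm_reasoner import LLMReasoner
from prompts import templates

options = {
    "API_KEY": os.getenv("OPENAI_API_KEY"),  # assumption: app.py resolves this via st.session_state
    "model_family": "OpenAI",
    "model_name": "gpt-4o-2024-05-13",
    "max_tokens": 500,
    "temperature": 0.0,
}

client = LLMReasoner(options)
prompt = templates["no_evidence"].format(claim="Vitamin C prevents the common cold.")
# Expected to return a single JSON object as raw text
print(client.run_inference(prompt))
```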
prompts.py ADDED
@@ -0,0 +1,18 @@
+templates = {
+    "no_evidence": """
+    You are an AI model tasked with verifying claims related to medical and health topics using zero-shot learning. Your job is to analyze a given claim and decide whether the available evidence and your general medical knowledge would likely SUPPORT or CONTRADICT the claim.
+    Claim to Evaluate: <claim> {claim} </claim>
+    Guidelines:
+    Evaluate the claim's plausibility based on general medical knowledge.
+    Consider the specificity and credibility of any numbers or percentages.
+    Analyze the context and scope of the claim.
+    Assess any potential biases or limitations.
+    Output Format:
+    After your analysis, output exactly one JSON object with two keys:
+    "reasoning": A brief explanation (one or two sentences).
+    "decision": Either "SUPPORT" or "CONTRADICT" (uppercase, no additional text).
+    Do not add markdown formatting, code fences, or additional text. The output must start with {{ and end with }}.
+    Example Output: {{"reasoning": "Your brief explanation here (one or two sentences).", "decision": "SUPPORT or CONTRADICT"}}
+    Now, please evaluate the claim above.
+    """
+}
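
Because this template is rendered with `str.format`, every literal brace in the JSON instructions must be doubled (`{{ ... }}`) so that only `{claim}` is substituted; otherwise `.format()` raises on the stray braces. A quick sketch of the render step (the claim text is a made-up example):

```python
from prompts import templates

claim = "Aspirin reduces the risk of heart attack."
prompt = templates["no_evidence"].format(claim=claim)

# {{ and }} render as literal braces, so the JSON example survives intact
assert "{claim}" not in prompt
assert '{"reasoning":' in prompt
print(prompt)
```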