qq8933 committed on
Commit
1e55599
1 Parent(s): 06d5b88

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +51 -1
README.md CHANGED
@@ -105,10 +105,60 @@ async def get_prediction(input_request: InputRequest):
105
  raise HTTPException(status_code=500, detail=str(e))
106
 
107
  ```
108
-
109
  ```
110
  uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
111
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  ## Training procedure
114
 
 
105
  raise HTTPException(status_code=500, detail=str(e))
106
 
107
  ```
108
+ Run the pprm server:
109
  ```
110
  uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
111
  ```
112
+ Send a request to the pprm server:
113
+ ```
114
+ # qeustion,answer_1,answer_2 = 'What is the capital of France?', 'Berlin', 'Paris'
115
+ # {'yes_logit': -24.26136016845703, 'no_logit': 19.517587661743164, 'logit_difference': -43.778947830200195}
116
+ # Is answer_1 better than answer_2? yes or no
117
+ # Entry point for the reward model
118
+ def request_prediction(
119
+ qeustion, answer_1, answer_2, url="http://10.140.24.56:10085/predict"
120
+ ):
121
+ """
122
+ Sends a POST request to the FastAPI server to get a prediction.
123
+
124
+ Args:
125
+ - qeustion, answer_1, answer_2 (str): The question and the two candidate answers to compare.
126
+ - url (str): The API endpoint URL. Defaults to 'http://10.140.24.56:10085/predict'.
127
+
128
+ Returns:
129
+ - dict: The response from the API containing prediction results.
130
+ """
131
+ headers = {"Content-Type": "application/json"}
132
+ payload = {
133
+ "text": json.dumps(
134
+ {"qeustion": qeustion, "answer_1": answer_1, "answer_2": answer_2}
135
+ )
136
+ }
137
+
138
+ response = requests.post(url, json=payload, headers=headers, timeout=TIMEOUT_PRM)
139
+ response.raise_for_status() # Raises an HTTPError if the response code was unsuccessful
140
+ return response.json() # Return the JSON response as a dictionary
141
+
142
+ def cal_reward(question, ans, ans2="I don't know"):
143
+ if ans2 in DUMMY_ANSWERS:#I don't know
144
+ return 1
145
+ if ans in DUMMY_ANSWERS:
146
+ return 0
147
+ urls = copy.deepcopy(prm_servers)
148
+ random.shuffle(urls)
149
+ for url in urls:
150
+ try:
151
+ response = request_prediction(question, ans, ans2, url)
152
+ return math.exp(response["yes_logit"]) / (
153
+ math.exp(response["yes_logit"]) + math.exp(response["no_logit"])
154
+ )
155
+ except Exception as e:
156
+ # print(e)
157
+ continue
158
+ print(Exception("All prm servers are down"))
159
+ # get_clients()
160
+ return cal_reward(question, ans, ans2)
161
+ ```
162
 
163
  ## Training procedure
164