AurelioAguirre committed on
Commit ab616bd · 1 Parent(s): 0cbf645

Fixed issue with the server not managing request lifecycle correctly

Files changed (1)
  1. main/api.py +35 -24
main/api.py CHANGED
@@ -63,18 +63,26 @@ class InferenceApi(LitAPI):
             stream: bool = False
     ) -> Any:
         """Make an authenticated request to the LLM Server."""
+        base_url = self.llm_config.get('host', 'http://localhost:8001')
+        full_endpoint = f"{base_url.rstrip('/')}/{self._get_endpoint(endpoint).lstrip('/')}"
+
         try:
-            async with await self._get_client() as client:
-                if stream:
-                    # Return the context manager directly, don't await it
-                    return client.stream(
-                        method,
-                        self._get_endpoint(endpoint),
-                        params=params,
-                        json=json
-                    )
-                else:
-                    response = await client.request(
+            self.logger.info(f"Making {method} request to: {full_endpoint}")
+            # Create client outside the with block for streaming
+            client = await self._get_client()
+
+            if stream:
+                # For streaming, return both client and response context managers
+                return client, client.stream(
+                    method,
+                    self._get_endpoint(endpoint),
+                    params=params,
+                    json=json
+                )
+            else:
+                # For non-streaming, use context manager
+                async with client as c:
+                    response = await c.request(
                         method,
                         self._get_endpoint(endpoint),
                         params=params,
@@ -84,7 +92,7 @@ class InferenceApi(LitAPI):
                     return response

         except Exception as e:
-            self.logger.error(f"Error in request to {endpoint}: {str(e)}")
+            self.logger.error(f"Error in request to {full_endpoint}: {str(e)}")
             raise

     def predict(self, x: str, **kwargs) -> Iterator[str]:
@@ -235,18 +243,21 @@ class InferenceApi(LitAPI):
         self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")

         try:
-            async with await self._make_request(
-                "POST",
-                "generate_stream",
-                json={
-                    "prompt": prompt,
-                    "system_message": system_message,
-                    "max_new_tokens": max_new_tokens
-                },
-                stream=True
-            ) as response:
-                async for chunk in response.aiter_text():
-                    yield chunk
+            client, stream_cm = await self._make_request(
+                "POST",
+                "generate_stream",
+                json={
+                    "prompt": prompt,
+                    "system_message": system_message,
+                    "max_new_tokens": max_new_tokens
+                },
+                stream=True
+            )
+
+            async with client:
+                async with stream_cm as response:
+                    async for chunk in response.aiter_text():
+                        yield chunk

         except Exception as e:
             self.logger.error(f"Error in generate_stream: {str(e)}")