AurelioAguirre committed on
Commit
63fdbaa
·
1 Parent(s): ff91b48

fixing pydantic issue v6

Browse files
Files changed (1) hide show
  1. main/api.py +30 -9
main/api.py CHANGED
@@ -1,6 +1,7 @@
1
  import httpx
2
- from typing import Optional, AsyncIterator, Dict, Any
3
  import logging
 
4
  from litserve import LitAPI
5
  from pydantic import BaseModel
6
 
@@ -24,10 +25,25 @@ class InferenceApi(LitAPI):
24
  )
25
  self.logger.info(f"Inference API setup completed on device: {device}")
26
 
27
- async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
28
  """
29
- Main prediction method required by LitAPI.
30
- Always yields, either chunks in streaming mode or complete response in non-streaming mode.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
  if self.stream:
33
  async for chunk in self.generate_stream(x, **kwargs):
@@ -42,12 +58,17 @@ class InferenceApi(LitAPI):
42
  return request["prompt"]
43
  return request
44
 
45
- def encode_response(self, output: AsyncIterator[str], **kwargs) -> AsyncIterator[Dict[str, str]]:
46
  """Convert the model output to a response payload."""
47
- async def wrapper():
48
- async for chunk in output:
49
- yield {"generated_text": chunk}
50
- return wrapper()
 
 
 
 
 
51
 
52
  async def generate_response(
53
  self,
 
1
  import httpx
2
+ from typing import Optional, AsyncIterator, Dict, Any, Iterator
3
  import logging
4
+ import asyncio
5
  from litserve import LitAPI
6
  from pydantic import BaseModel
7
 
 
25
  )
26
  self.logger.info(f"Inference API setup completed on device: {device}")
27
 
28
def predict(self, x: str, **kwargs) -> Iterator[str]:
    """Synchronous prediction entry point required by LitAPI.

    Bridges sync-to-async by stepping the internal async generator
    (`_async_predict`) one item at a time on a private event loop.

    Args:
        x: The decoded prompt string.
        **kwargs: Forwarded verbatim to `_async_predict`.

    Yields:
        Response chunks in streaming mode, or one complete response
        string in non-streaming mode.
    """
    # NOTE(review): `asyncio.get_event_loop()` is deprecated since 3.10
    # and raises RuntimeError in worker threads that have no current
    # loop; the original also never closed the loop. Use a dedicated
    # loop per call and always release it.
    loop = asyncio.new_event_loop()
    agen = self._async_predict(x, **kwargs)
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        # Ensure the async generator runs its cleanup (finally blocks,
        # open connections) even if the caller abandons iteration early.
        loop.run_until_complete(agen.aclose())
        loop.close()
44
+ async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
45
+ """
46
+ Internal async prediction method.
47
  """
48
  if self.stream:
49
  async for chunk in self.generate_stream(x, **kwargs):
 
58
  return request["prompt"]
59
  return request
60
 
61
def encode_response(self, output: Iterator[str], **kwargs) -> Dict[str, Any]:
    """Convert the model output to a response payload.

    Streaming mode passes the iterator through untouched so chunks are
    emitted as they are produced. Non-streaming mode drains exactly one
    item (the complete text); an exhausted iterator maps to "".
    """
    if self.stream:
        # Hand the iterator straight to the server layer.
        return {"generated_text": output}
    # Non-streaming: the iterator is expected to carry a single,
    # complete response; fall back to "" when it is empty.
    return {"generated_text": next(output, "")}
72
 
73
  async def generate_response(
74
  self,