# api.py, in the main directory of the Inference API module: a LitServe API that forwards generation requests to the LLM Server.

import httpx
from typing import Optional, AsyncIterator, Dict, Any
import logging
from litserve import LitAPI
from pydantic import BaseModel

class GenerationResponse(BaseModel):
    generated_text: str

class InferenceApi(LitAPI):
    def __init__(self):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        self.client = None

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize HTTP client"""
        self._device = device
        self.client = httpx.AsyncClient(
            base_url="http://localhost:8002",  # We'll need to make this configurable
            timeout=60.0
        )
        self.logger.info(f"Inference API setup completed on device: {device}")

    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """
        Main prediction method required by LitAPI.
        Always yields, either chunks in streaming mode or complete response in non-streaming mode.
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            response = await self.generate_response(x, **kwargs)
            yield response

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format."""
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: AsyncIterator[str], **kwargs) -> AsyncIterator[Dict[str, str]]:
        """Convert the model output to a response payload."""
        async def wrapper():
            async for chunk in output:
                yield {"generated_text": chunk}
        return wrapper()

    async def generate_response(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")

        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]

        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")

        try:
            async with self.client.stream(
                    "POST",
                    "/api/v1/generate/stream",
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens
                    }
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk

        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup method - close HTTP client"""
        if self.client:
            await self.client.aclose()
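
# Usage sketch (illustrative only, not part of this module): one way this class
# could be served with LitServe. `stream=True` activates the streaming branch of
# predict(); the port shown here is an assumed placeholder, and the real entry
# point would live in a separate launcher script.
#
#     import litserve as ls
#     from api import InferenceApi
#
#     if __name__ == "__main__":
#         server = ls.LitServer(InferenceApi(), stream=True)
#         server.run(port=8001)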