import httpx
from typing import Optional, AsyncIterator, Union, Any
import logging
from litserve import LitAPI

class InferenceApi(LitAPI):
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL from config
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)
        self.client = None  # Will be initialized in setup()

        # Set request timeout from config
        self.request_timeout = float(self.timeout)

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize HTTP client"""
        self._device = device  # Store device as required by LitAPI
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )
        self.logger.info(f"Inference API setup completed on device: {device}")

    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """
        Main prediction method required by LitAPI.
        If streaming is enabled, yields chunks as they arrive; otherwise yields
        the complete response as a single item (a `return <value>` is not
        allowed inside an async generator).
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            yield await self.generate_response(x, **kwargs)

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format."""
        # For our case, we expect the request to be text
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: Union[str, AsyncIterator[str]], **kwargs) -> Union[dict, AsyncIterator[dict]]:
        """Convert the model output to a response payload."""
        if self.stream:
            # For streaming, wrap each chunk in a dict and return the async generator
            async def stream_wrapper():
                async for chunk in output:
                    yield {"generated_text": chunk}
            return stream_wrapper()
        else:
            # For non-streaming, return the complete response in one payload
            return {"generated_text": output}

    async def generate_response(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")

        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]

        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")

        try:
            async with self.client.stream(
                    "POST",
                    "/api/v1/generate/stream",
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens
                    }
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk

        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    # ... [rest of the methods remain the same: generate_embedding, check_system_status, etc.]

    async def cleanup(self):
        """Cleanup method - close HTTP client"""
        if self.client:
            await self.client.aclose()

    def log(self, key: str, value: Any):
        """Override log method to use our logger if queue not set"""
        if self._logger_queue is None:
            self.logger.info(f"Log event: {key}={value}")
        else:
            super().log(key, value)
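
# Usage sketch (not part of the original file): a minimal, hypothetical example of
# serving InferenceApi with LitServe. The config keys mirror the ones read in
# __init__ above; the base_url, ports, and stream setting are illustrative
# assumptions, not values taken from the source.
if __name__ == "__main__":
    import litserve as ls

    example_config = {
        "llm_server": {
            "base_url": "http://localhost:8001",  # hypothetical LLM Server address
            "timeout": 60,
        }
    }
    api = InferenceApi(example_config)
    # stream=True lets predict/encode_response yield chunks as they are produced
    server = ls.LitServer(api, stream=True)
    server.run(port=8000)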