import logging
from typing import AsyncIterator, Dict, List, Optional, Union

import httpx

class InferenceApi:
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Base URL of the LLM Server and request timeout in seconds (default 60)
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)

        # Initialize HTTP client
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )

        self.logger.info("Inference API initialized successfully")

    async def generate_response(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")

        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
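            # Surface HTTP error responses (4xx/5xx) as exceptions.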
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]

        except Exception as e:
            self.logger.error(f"Error in generate_response: {e}")
            raise

    async def generate_stream(
            self,
            prompt: str,
            system_message: Optional[str] = None,
            max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate a streaming response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")

        try:
            async with self.client.stream(
                    "POST",
                    "/api/v1/generate/stream",
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens
                    }
            ) as response:
                response.raise_for_status()
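                # Relay decoded text chunks to the caller as they arrive.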
                async for chunk in response.aiter_text():
                    yield chunk

        except Exception as e:
            self.logger.error(f"Error in generate_stream: {e}")
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate embedding by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")

        try:
            response = await self.client.post(
                "/api/v1/embedding",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data["embedding"]

        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {e}")
            raise

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """
        Get system status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/status")
            response.raise_for_status()
            return response.json()

        except Exception as e:
            self.logger.error(f"Error getting system status: {e}")
            raise

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """
        Get system validation status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/validate")
            response.raise_for_status()
            return response.json()

        except Exception as e:
            self.logger.error(f"Error validating system: {e}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize a model on the LLM Server.
        """
        try:
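            # model_name, when provided, is forwarded as a query parameter;
            # when omitted, no parameter is sent at all.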
            response = await self.client.post(
                "/api/v1/model/initialize",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()

        except Exception as e:
            self.logger.error(f"Error initializing model: {e}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize an embedding model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize/embedding",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()

        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {e}")
            raise

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()
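

# --- Usage sketch ------------------------------------------------------------
# Minimal end-to-end example of driving the client; illustrative only. The
# base_url, prompts, and token limit below are assumptions for demonstration,
# so point the config at your own running LLM Server.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        config = {"llm_server": {"base_url": "http://localhost:8000", "timeout": 60}}
        api = InferenceApi(config)
        try:
            # One-shot generation
            print(await api.generate_response("Hello!", max_new_tokens=64))

            # Streaming generation: chunks are printed as they arrive
            async for chunk in api.generate_stream("Tell me a short story"):
                print(chunk, end="", flush=True)
        finally:
            # Release the underlying connection pool
            await api.close()

    asyncio.run(_demo())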