import logging
from typing import AsyncIterator, Dict, List, Optional, Union

import httpx


class InferenceApi:
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL and timeout from config
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)

        # Initialize a single shared HTTP client for all requests
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )
        self.logger.info("Inference API initialized successfully")

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate a streaming response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            ) as response:
                response.raise_for_status()
                # Yield text chunks as they arrive instead of buffering the body
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {e}")
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/embedding",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {e}")
            raise

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """
        Get system status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/status")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error getting system status: {e}")
            raise

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """
        Get system validation status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/validate")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error validating system: {e}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize a model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize",
                # Only send the query parameter when a model name was given
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing model: {e}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize an embedding model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize/embedding",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {e}")
            raise

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()
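

# Usage sketch (illustrative, not part of the original module). The config
# shape matches what __init__ reads; the base_url points at a hypothetical
# local LLM Server exposing the /api/v1 routes used above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        config = {"llm_server": {"base_url": "http://localhost:8000", "timeout": 60}}
        api = InferenceApi(config)
        try:
            # Non-streaming completion
            text = await api.generate_response("Hello!", system_message="Be brief.")
            print(text)
            # Streaming completion, printed chunk by chunk as it arrives
            async for chunk in api.generate_stream("Tell me a short story"):
                print(chunk, end="", flush=True)
        finally:
            # Always release the underlying HTTP connection pool
            await api.close()

    asyncio.run(_demo())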