Spaces · Runtime error
Commit 47031d7 · 1 parent: a717933
FIRST

Files changed:
- .gitignore +50 -0
- .idea/.gitignore +3 -0
- .idea/Inference-API.iml +9 -0
- .idea/misc.xml +9 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- Dockerfile +24 -0
- app/__init__.py +0 -0
- app/api.py +159 -0
- app/config.yaml +11 -0
- app/main.py +56 -0
- app/routes.py +123 -0
- app/schemas.py +29 -0
- requirements.txt +22 -0
.gitignore ADDED
@@ -0,0 +1,50 @@
+# Virtual Environment
+myenv/
+venv/
+ENV/
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# IDEs and editors
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+.project
+.settings/
+.classpath
+
+# Logs and databases
+*.log
+*.sqlite
+*.db
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
.idea/Inference-API.iml ADDED
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="JAVA_MODULE" version="4">
+  <component name="NewModuleRootManager" inherit-compiler-output="true">
+    <exclude-output />
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/misc.xml ADDED
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.13 (Inference-API)" />
+  </component>
+  <component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="Python 3.13 (Inference-API)" project-jdk-type="Python SDK">
+    <output url="file://$PROJECT_DIR$/out" />
+  </component>
+</project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/Inference-API.iml" filepath="$PROJECT_DIR$/.idea/Inference-API.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
Dockerfile ADDED
@@ -0,0 +1,24 @@
+# Use Python 3.12 slim image as base
+FROM python:3.12-slim
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements first to leverage Docker cache
+COPY requirements.txt .
+
+# Install dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the application code into app/ so "python -m app.main" resolves
+COPY app/ ./app/
+
+# Set environment variables
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+
+# Expose the port the application listens on (8001, per app/config.yaml)
+EXPOSE 8001
+
+# Command to run the application
+CMD ["python", "-m", "app.main"]
app/__init__.py ADDED
(empty file)
app/api.py ADDED
@@ -0,0 +1,159 @@
+import httpx
+from typing import Optional, AsyncIterator, List, Dict, Union
+import logging
+
+class InferenceApi:
+    def __init__(self, config: dict):
+        """Initialize the Inference API with configuration."""
+        self.logger = logging.getLogger(__name__)
+        self.logger.info("Initializing Inference API")
+
+        # Get base URL from config
+        self.base_url = config["llm_server"]["base_url"]
+        self.timeout = config["llm_server"].get("timeout", 60)
+
+        # Initialize HTTP client
+        self.client = httpx.AsyncClient(
+            base_url=self.base_url,
+            timeout=self.timeout
+        )
+
+        self.logger.info("Inference API initialized successfully")
+
+    async def generate_response(
+        self,
+        prompt: str,
+        system_message: Optional[str] = None,
+        max_new_tokens: Optional[int] = None
+    ) -> str:
+        """
+        Generate a complete response by forwarding the request to the LLM Server.
+        """
+        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
+
+        try:
+            response = await self.client.post(
+                "/api/v1/generate",
+                json={
+                    "prompt": prompt,
+                    "system_message": system_message,
+                    "max_new_tokens": max_new_tokens
+                }
+            )
+            response.raise_for_status()
+            data = response.json()
+            return data["generated_text"]
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_response: {str(e)}")
+            raise
+
+    async def generate_stream(
+        self,
+        prompt: str,
+        system_message: Optional[str] = None,
+        max_new_tokens: Optional[int] = None
+    ) -> AsyncIterator[str]:
+        """
+        Generate a streaming response by forwarding the request to the LLM Server.
+        """
+        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
+
+        try:
+            async with self.client.stream(
+                "POST",
+                "/api/v1/generate/stream",
+                json={
+                    "prompt": prompt,
+                    "system_message": system_message,
+                    "max_new_tokens": max_new_tokens
+                }
+            ) as response:
+                response.raise_for_status()
+                async for chunk in response.aiter_text():
+                    yield chunk
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_stream: {str(e)}")
+            raise
+
+    async def generate_embedding(self, text: str) -> List[float]:
+        """
+        Generate an embedding by forwarding the request to the LLM Server.
+        """
+        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
+
+        try:
+            response = await self.client.post(
+                "/api/v1/embedding",
+                json={"text": text}
+            )
+            response.raise_for_status()
+            data = response.json()
+            return data["embedding"]
+
+        except Exception as e:
+            self.logger.error(f"Error in generate_embedding: {str(e)}")
+            raise
+
+    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
+        """
+        Get system status from the LLM Server.
+        """
+        try:
+            response = await self.client.get("/api/v1/system/status")
+            response.raise_for_status()
+            return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error getting system status: {str(e)}")
+            raise
+
+    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
+        """
+        Get system validation status from the LLM Server.
+        """
+        try:
+            response = await self.client.get("/api/v1/system/validate")
+            response.raise_for_status()
+            return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error validating system: {str(e)}")
+            raise
+
+    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
+        """
+        Initialize a model on the LLM Server.
+        """
+        try:
+            response = await self.client.post(
+                "/api/v1/model/initialize",
+                params={"model_name": model_name} if model_name else None
+            )
+            response.raise_for_status()
+            return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error initializing model: {str(e)}")
+            raise
+
+    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
+        """
+        Initialize an embedding model on the LLM Server.
+        """
+        try:
+            response = await self.client.post(
+                "/api/v1/model/initialize/embedding",
+                params={"model_name": model_name} if model_name else None
+            )
+            response.raise_for_status()
+            return response.json()
+
+        except Exception as e:
+            self.logger.error(f"Error initializing embedding model: {str(e)}")
+            raise
+
+    async def close(self):
+        """Close the HTTP client session."""
+        await self.client.aclose()
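For orientation, here is a minimal sketch of how InferenceApi might be exercised on its own, assuming an LLM Server is reachable at the base_url you supply (the config dict and URL below are illustrative, not part of this commit):

import asyncio
from app.api import InferenceApi

async def demo():
    # Illustrative config; point base_url at your own LLM Server.
    config = {"llm_server": {"base_url": "http://localhost:8000", "timeout": 60}}
    api = InferenceApi(config)
    try:
        text = await api.generate_response(prompt="Hello!", max_new_tokens=32)
        print(text)
        # Consume the streaming variant chunk by chunk
        async for chunk in api.generate_stream(prompt="Count to five."):
            print(chunk, end="", flush=True)
    finally:
        await api.close()

asyncio.run(demo())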
app/config.yaml ADDED
@@ -0,0 +1,11 @@
+server:
+  port: 8001
+  timeout: 60
+
+llm_server:
+  base_url: "https://teamgenki-llmserver.hf.space:7680"  # URL of your LLM Server
+  timeout: 60  # Timeout for requests to LLM Server
+
+logging:
+  level: "INFO"
+  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
app/main.py ADDED
@@ -0,0 +1,56 @@
+"""
+LLM Inference Server main application using LitServe framework.
+"""
+import litserve as ls
+import yaml
+import logging
+from pathlib import Path
+from .routes import router, init_router
+
+def setup_logging():
+    """Set up basic logging configuration"""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    )
+    return logging.getLogger(__name__)
+
+def load_config():
+    """Load configuration from config.yaml"""
+    config_path = Path(__file__).parent / "config.yaml"
+    with open(config_path) as f:
+        return yaml.safe_load(f)
+
+def main():
+    """Main function to set up and run the inference server."""
+    logger = setup_logging()
+
+    try:
+        # Load configuration
+        config = load_config()
+
+        # Initialize the router with our config
+        init_router(config)
+
+        # Create LitServer instance
+        server = ls.LitServer(
+            timeout=config.get("server", {}).get("timeout", 60),
+            max_batch_size=1,
+            track_requests=True
+        )
+
+        # Add our routes to the server's FastAPI app
+        server.app.include_router(router, prefix="/api/v1")
+
+        # Get port from config or use default
+        port = config.get("server", {}).get("port", 8001)
+
+        logger.info(f"Starting server on port {port}")
+        server.run(port=port)
+
+    except Exception as e:
+        logger.error(f"Server initialization failed: {str(e)}")
+        raise
+
+if __name__ == "__main__":
+    main()
app/routes.py ADDED
@@ -0,0 +1,123 @@
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+from typing import Optional
+from .api import InferenceApi
+from .schemas import (
+    GenerateRequest,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    SystemStatusResponse,
+    ValidationResponse
+)
+import logging
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+api = None
+
+def init_router(config: dict):
+    """Initialize router with config and Inference API instance"""
+    global api
+    api = InferenceApi(config)
+    logger.info("Router initialized with Inference API instance")
+
+@router.post("/generate")
+async def generate_text(request: GenerateRequest):
+    """Generate text response from prompt"""
+    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
+    try:
+        response = await api.generate_response(
+            prompt=request.prompt,
+            system_message=request.system_message,
+            max_new_tokens=request.max_new_tokens
+        )
+        logger.info("Successfully generated response")
+        return {"generated_text": response}
+    except Exception as e:
+        logger.error(f"Error in generate_text endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/generate/stream")
+async def generate_stream(request: GenerateRequest):
+    """Generate streaming text response from prompt"""
+    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
+    try:
+        # Wrap the async generator in a StreamingResponse so chunks are
+        # streamed to the client instead of being serialized as a whole.
+        return StreamingResponse(
+            api.generate_stream(
+                prompt=request.prompt,
+                system_message=request.system_message,
+                max_new_tokens=request.max_new_tokens
+            ),
+            media_type="text/plain"
+        )
+    except Exception as e:
+        logger.error(f"Error in generate_stream endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/embedding", response_model=EmbeddingResponse)
+async def generate_embedding(request: EmbeddingRequest):
+    """Generate embedding vector from text"""
+    logger.info(f"Received embedding request for text: {request.text[:50]}...")
+    try:
+        embedding = await api.generate_embedding(request.text)
+        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
+        return EmbeddingResponse(
+            embedding=embedding,
+            dimension=len(embedding)
+        )
+    except Exception as e:
+        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.get("/system/status",
+            response_model=SystemStatusResponse,
+            summary="Check System Status",
+            description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
+async def check_system():
+    """Get system status from LLM Server"""
+    try:
+        return await api.check_system_status()
+    except Exception as e:
+        logger.error(f"Error checking system status: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.get("/system/validate",
+            response_model=ValidationResponse,
+            summary="Validate System Configuration",
+            description="Validates system configuration, folders, and model setup")
+async def validate_system():
+    """Get system validation status from LLM Server"""
+    try:
+        return await api.validate_system()
+    except Exception as e:
+        logger.error(f"Error validating system: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/model/initialize",
+             summary="Initialize default or specified model",
+             description="Initialize model for use. Uses default model from config if none specified.")
+async def initialize_model(model_name: Optional[str] = None):
+    """Initialize a model for use"""
+    try:
+        return await api.initialize_model(model_name)
+    except Exception as e:
+        logger.error(f"Error initializing model: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/model/initialize/embedding",
+             summary="Initialize embedding model",
+             description="Initialize a separate model specifically for generating embeddings")
+async def initialize_embedding_model(model_name: Optional[str] = None):
+    """Initialize a model specifically for embeddings"""
+    try:
+        return await api.initialize_embedding_model(model_name)
+    except Exception as e:
+        logger.error(f"Error initializing embedding model: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.on_event("shutdown")
+async def shutdown_event():
+    """Clean up resources on shutdown"""
+    if api:
+        await api.close()
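Once the server is running, these routes can be exercised with a plain HTTP client. A minimal sketch, assuming the service listens on localhost:8001 (the port configured in app/config.yaml) under the /api/v1 prefix added in app/main.py:

import httpx

BASE = "http://localhost:8001/api/v1"  # assumed host/port

with httpx.Client(base_url=BASE, timeout=60) as client:
    # Non-streaming generation
    r = client.post("/generate", json={"prompt": "Hello", "max_new_tokens": 32})
    r.raise_for_status()
    print(r.json()["generated_text"])

    # Streaming generation: iterate over text chunks as they arrive
    with client.stream("POST", "/generate/stream", json={"prompt": "Hello"}) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_text():
            print(chunk, end="")

    # Embedding
    r = client.post("/embedding", json={"text": "Hello"})
    r.raise_for_status()
    print(r.json()["dimension"])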
app/schemas.py ADDED
@@ -0,0 +1,29 @@
+from pydantic import BaseModel
+from typing import Optional, List, Dict, Union
+
+class GenerateRequest(BaseModel):
+    prompt: str
+    system_message: Optional[str] = None
+    max_new_tokens: Optional[int] = None
+
+class EmbeddingRequest(BaseModel):
+    text: str
+
+class EmbeddingResponse(BaseModel):
+    embedding: List[float]
+    dimension: int
+
+class SystemStatusResponse(BaseModel):
+    """Pydantic model for system status response"""
+    cpu: Optional[Dict[str, Union[float, str]]] = None
+    memory: Optional[Dict[str, Union[float, str]]] = None
+    gpu: Optional[Dict[str, Union[bool, str, float]]] = None
+    storage: Optional[Dict[str, str]] = None
+    model: Optional[Dict[str, Union[bool, str]]] = None
+
+class ValidationResponse(BaseModel):
+    config_validation: Dict[str, bool]
+    model_validation: Dict[str, bool]
+    folder_validation: Dict[str, bool]
+    overall_status: str
+    issues: List[str]
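To illustrate the request and response contracts these models enforce, a quick sketch (the values are made up):

from app.schemas import GenerateRequest, EmbeddingResponse

# Optional fields default to None when omitted
req = GenerateRequest(prompt="Summarize this.", max_new_tokens=64)
print(req.model_dump())
# -> {'prompt': 'Summarize this.', 'system_message': None, 'max_new_tokens': 64}

resp = EmbeddingResponse(embedding=[0.1, 0.2, 0.3], dimension=3)
print(resp.dimension)  # 3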
requirements.txt ADDED
@@ -0,0 +1,22 @@
+annotated-types==0.7.0
+anyio==4.8.0
+certifi==2024.12.14
+click==8.1.8
+fastapi==0.115.6
+h11==0.14.0
+httpcore==1.0.7
+httptools==0.6.4
+httpx==0.28.1
+idna==3.10
+litserve==0.2.5
+pydantic==2.10.4
+pydantic_core==2.27.2
+python-dotenv==1.0.1
+PyYAML==6.0.2
+sniffio==1.3.1
+starlette==0.41.3
+typing_extensions==4.12.2
+uvicorn==0.34.0
+uvloop==0.21.0
+watchfiles==1.0.3
+websockets==14.1