AurelioAguirre committed
Commit 47031d7 · 1 Parent(s): a717933
.gitignore ADDED
@@ -0,0 +1,50 @@
+ # Virtual Environment
+ myenv/
+ venv/
+ ENV/
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # IDEs and editors
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ *~
+ .project
+ .settings/
+ .classpath
+
+ # Logs and databases
+ *.log
+ *.sqlite
+ *.db
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/Inference-API.iml ADDED
@@ -0,0 +1,9 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="JAVA_MODULE" version="4">
+   <component name="NewModuleRootManager" inherit-compiler-output="true">
+     <exclude-output />
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/misc.xml ADDED
@@ -0,0 +1,9 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.13 (Inference-API)" />
+   </component>
+   <component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="Python 3.13 (Inference-API)" project-jdk-type="Python SDK">
+     <output url="file://$PROJECT_DIR$/out" />
+   </component>
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/Inference-API.iml" filepath="$PROJECT_DIR$/.idea/Inference-API.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ # Use Python 3.12 slim image as base
+ FROM python:3.12-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements first to leverage Docker cache
+ COPY requirements.txt .
+
+ # Install dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application code under /app/app so `python -m app.main` resolves
+ COPY app/ ./app/
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+ ENV PYTHONUNBUFFERED=1
+
+ # Expose the port the application runs on
+ EXPOSE 7680
+
+ # Command to run the application
+ CMD ["python", "-m", "app.main"]
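
As a quick local check (the image name here is illustrative), building with `docker build -t inference-api .` and running with `docker run -p 8001:8001 inference-api` should start the service. Note that the app binds the port set in app/config.yaml (8001 by default), while `EXPOSE 7680` is only documentation metadata and does not publish anything by itself.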
app/__init__.py ADDED
File without changes
app/api.py ADDED
@@ -0,0 +1,159 @@
+ import httpx
+ from typing import Optional, AsyncIterator, List, Dict, Union
+ import logging
+
+ class InferenceApi:
+     def __init__(self, config: dict):
+         """Initialize the Inference API with configuration."""
+         self.logger = logging.getLogger(__name__)
+         self.logger.info("Initializing Inference API")
+
+         # Get base URL from config
+         self.base_url = config["llm_server"]["base_url"]
+         self.timeout = config["llm_server"].get("timeout", 60)
+
+         # Initialize HTTP client
+         self.client = httpx.AsyncClient(
+             base_url=self.base_url,
+             timeout=self.timeout
+         )
+
+         self.logger.info("Inference API initialized successfully")
+
+     async def generate_response(
+         self,
+         prompt: str,
+         system_message: Optional[str] = None,
+         max_new_tokens: Optional[int] = None
+     ) -> str:
+         """
+         Generate a complete response by forwarding the request to the LLM Server.
+         """
+         self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
+
+         try:
+             response = await self.client.post(
+                 "/api/v1/generate",
+                 json={
+                     "prompt": prompt,
+                     "system_message": system_message,
+                     "max_new_tokens": max_new_tokens
+                 }
+             )
+             response.raise_for_status()
+             data = response.json()
+             return data["generated_text"]
+
+         except Exception as e:
+             self.logger.error(f"Error in generate_response: {str(e)}")
+             raise
+
+     async def generate_stream(
+         self,
+         prompt: str,
+         system_message: Optional[str] = None,
+         max_new_tokens: Optional[int] = None
+     ) -> AsyncIterator[str]:
+         """
+         Generate a streaming response by forwarding the request to the LLM Server.
+         """
+         self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
+
+         try:
+             async with self.client.stream(
+                 "POST",
+                 "/api/v1/generate/stream",
+                 json={
+                     "prompt": prompt,
+                     "system_message": system_message,
+                     "max_new_tokens": max_new_tokens
+                 }
+             ) as response:
+                 response.raise_for_status()
+                 async for chunk in response.aiter_text():
+                     yield chunk
+
+         except Exception as e:
+             self.logger.error(f"Error in generate_stream: {str(e)}")
+             raise
+
+     async def generate_embedding(self, text: str) -> List[float]:
+         """
+         Generate an embedding by forwarding the request to the LLM Server.
+         """
+         self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
+
+         try:
+             response = await self.client.post(
+                 "/api/v1/embedding",
+                 json={"text": text}
+             )
+             response.raise_for_status()
+             data = response.json()
+             return data["embedding"]
+
+         except Exception as e:
+             self.logger.error(f"Error in generate_embedding: {str(e)}")
+             raise
+
+     async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
+         """
+         Get system status from the LLM Server.
+         """
+         try:
+             response = await self.client.get("/api/v1/system/status")
+             response.raise_for_status()
+             return response.json()
+
+         except Exception as e:
+             self.logger.error(f"Error getting system status: {str(e)}")
+             raise
+
+     async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
+         """
+         Get system validation status from the LLM Server.
+         """
+         try:
+             response = await self.client.get("/api/v1/system/validate")
+             response.raise_for_status()
+             return response.json()
+
+         except Exception as e:
+             self.logger.error(f"Error validating system: {str(e)}")
+             raise
+
+     async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
+         """
+         Initialize a model on the LLM Server.
+         """
+         try:
+             response = await self.client.post(
+                 "/api/v1/model/initialize",
+                 params={"model_name": model_name} if model_name else None
+             )
+             response.raise_for_status()
+             return response.json()
+
+         except Exception as e:
+             self.logger.error(f"Error initializing model: {str(e)}")
+             raise
+
+     async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
+         """
+         Initialize an embedding model on the LLM Server.
+         """
+         try:
+             response = await self.client.post(
+                 "/api/v1/model/initialize/embedding",
+                 params={"model_name": model_name} if model_name else None
+             )
+             response.raise_for_status()
+             return response.json()
+
+         except Exception as e:
+             self.logger.error(f"Error initializing embedding model: {str(e)}")
+             raise
+
+     async def close(self):
+         """Close the HTTP client session."""
+         await self.client.aclose()
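
For reference, a minimal usage sketch of this client, assuming the package is importable as `app` and an LLM Server is reachable at the configured base_url (the `demo` function name is illustrative):

    import asyncio
    import yaml
    from app.api import InferenceApi

    async def demo():
        # Reuse the same config file the service reads
        with open("app/config.yaml") as f:
            config = yaml.safe_load(f)

        api = InferenceApi(config)
        try:
            # One-shot generation
            text = await api.generate_response("Hello!", max_new_tokens=64)
            print(text)

            # Streaming generation, printed chunk by chunk
            async for chunk in api.generate_stream("Tell me a short story"):
                print(chunk, end="", flush=True)
        finally:
            # Always release the underlying HTTP connection pool
            await api.close()

    asyncio.run(demo())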
app/config.yaml ADDED
@@ -0,0 +1,11 @@
+ server:
+   port: 8001
+   timeout: 60
+
+ llm_server:
+   base_url: "https://teamgenki-llmserver.hf.space:7680" # URL of your LLM Server
+   timeout: 60 # Timeout for requests to LLM Server
+
+ logging:
+   level: "INFO"
+   format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
app/main.py ADDED
@@ -0,0 +1,56 @@
+ """
+ LLM Inference Server main application using LitServe framework.
+ """
+ import litserve as ls
+ import yaml
+ import logging
+ from pathlib import Path
+ from .routes import router, init_router
+
+ def setup_logging():
+     """Set up basic logging configuration"""
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+     )
+     return logging.getLogger(__name__)
+
+ def load_config():
+     """Load configuration from config.yaml"""
+     config_path = Path(__file__).parent / "config.yaml"
+     with open(config_path) as f:
+         return yaml.safe_load(f)
+
+ def main():
+     """Main function to set up and run the inference server."""
+     logger = setup_logging()
+
+     try:
+         # Load configuration
+         config = load_config()
+
+         # Initialize the router with our config
+         init_router(config)
+
+         # Create LitServer instance
+         server = ls.LitServer(
+             timeout=config.get("server", {}).get("timeout", 60),
+             max_batch_size=1,
+             track_requests=True
+         )
+
+         # Add our routes to the server's FastAPI app
+         server.app.include_router(router, prefix="/api/v1")
+
+         # Get port from config or use default
+         port = config.get("server", {}).get("port", 8001)
+
+         logger.info(f"Starting server on port {port}")
+         server.run(port=port)
+
+     except Exception as e:
+         logger.error(f"Server initialization failed: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     main()
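
With PYTHONPATH set to the repository root (as the Dockerfile does), `python -m app.main` should start the server on the configured port; running app/main.py directly as a script would break the relative `from .routes import ...` import.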
app/routes.py ADDED
@@ -0,0 +1,123 @@
+ from fastapi import APIRouter, HTTPException
+ from fastapi.responses import StreamingResponse
+ from typing import Optional
+ from .api import InferenceApi
+ from .schemas import (
+     GenerateRequest,
+     EmbeddingRequest,
+     EmbeddingResponse,
+     SystemStatusResponse,
+     ValidationResponse
+ )
+ import logging
+
+ router = APIRouter()
+ logger = logging.getLogger(__name__)
+ api: Optional[InferenceApi] = None
+
+ def init_router(config: dict):
+     """Initialize router with config and Inference API instance"""
+     global api
+     api = InferenceApi(config)
+     logger.info("Router initialized with Inference API instance")
+
+ @router.post("/generate")
+ async def generate_text(request: GenerateRequest):
+     """Generate text response from prompt"""
+     logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
+     try:
+         response = await api.generate_response(
+             prompt=request.prompt,
+             system_message=request.system_message,
+             max_new_tokens=request.max_new_tokens
+         )
+         logger.info("Successfully generated response")
+         return {"generated_text": response}
+     except Exception as e:
+         logger.error(f"Error in generate_text endpoint: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.post("/generate/stream")
+ async def generate_stream(request: GenerateRequest):
+     """Generate streaming text response from prompt"""
+     logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
+     try:
+         # Wrap the async generator in a StreamingResponse so chunks reach the client
+         return StreamingResponse(
+             api.generate_stream(
+                 prompt=request.prompt,
+                 system_message=request.system_message,
+                 max_new_tokens=request.max_new_tokens
+             ),
+             media_type="text/plain"
+         )
+     except Exception as e:
+         logger.error(f"Error in generate_stream endpoint: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.post("/embedding", response_model=EmbeddingResponse)
+ async def generate_embedding(request: EmbeddingRequest):
+     """Generate embedding vector from text"""
+     logger.info(f"Received embedding request for text: {request.text[:50]}...")
+     try:
+         embedding = await api.generate_embedding(request.text)
+         logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
+         return EmbeddingResponse(
+             embedding=embedding,
+             dimension=len(embedding)
+         )
+     except Exception as e:
+         logger.error(f"Error in generate_embedding endpoint: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.get("/system/status",
+             response_model=SystemStatusResponse,
+             summary="Check System Status",
+             description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
+ async def check_system():
+     """Get system status from LLM Server"""
+     try:
+         return await api.check_system_status()
+     except Exception as e:
+         logger.error(f"Error checking system status: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.get("/system/validate",
+             response_model=ValidationResponse,
+             summary="Validate System Configuration",
+             description="Validates system configuration, folders, and model setup")
+ async def validate_system():
+     """Get system validation status from LLM Server"""
+     try:
+         return await api.validate_system()
+     except Exception as e:
+         logger.error(f"Error validating system: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.post("/model/initialize",
+              summary="Initialize default or specified model",
+              description="Initialize model for use. Uses default model from config if none specified.")
+ async def initialize_model(model_name: Optional[str] = None):
+     """Initialize a model for use"""
+     try:
+         return await api.initialize_model(model_name)
+     except Exception as e:
+         logger.error(f"Error initializing model: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.post("/model/initialize/embedding",
+              summary="Initialize embedding model",
+              description="Initialize a separate model specifically for generating embeddings")
+ async def initialize_embedding_model(model_name: Optional[str] = None):
+     """Initialize a model specifically for embeddings"""
+     try:
+         return await api.initialize_embedding_model(model_name)
+     except Exception as e:
+         logger.error(f"Error initializing embedding model: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.on_event("shutdown")
+ async def shutdown_event():
+     """Clean up resources on shutdown"""
+     if api:
+         await api.close()
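
Once the service is up, a small httpx sketch can exercise these endpoints; this assumes a local deployment on port 8001 as in app/config.yaml:

    import httpx

    BASE = "http://localhost:8001/api/v1"  # assumed local deployment

    # Non-streaming generation
    r = httpx.post(f"{BASE}/generate",
                   json={"prompt": "Say hi", "max_new_tokens": 32},
                   timeout=60)
    r.raise_for_status()
    print(r.json()["generated_text"])

    # Streaming generation, consumed incrementally
    with httpx.stream("POST", f"{BASE}/generate/stream",
                      json={"prompt": "Say hi slowly"}, timeout=60) as resp:
        for chunk in resp.iter_text():
            print(chunk, end="")

    # Embedding
    r = httpx.post(f"{BASE}/embedding", json={"text": "hello world"}, timeout=60)
    print(r.json()["dimension"])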
app/schemas.py ADDED
@@ -0,0 +1,29 @@
+ from pydantic import BaseModel
+ from typing import Optional, List, Dict, Union
+
+ class GenerateRequest(BaseModel):
+     prompt: str
+     system_message: Optional[str] = None
+     max_new_tokens: Optional[int] = None
+
+ class EmbeddingRequest(BaseModel):
+     text: str
+
+ class EmbeddingResponse(BaseModel):
+     embedding: List[float]
+     dimension: int
+
+ class SystemStatusResponse(BaseModel):
+     """Pydantic model for system status response"""
+     cpu: Optional[Dict[str, Union[float, str]]] = None
+     memory: Optional[Dict[str, Union[float, str]]] = None
+     gpu: Optional[Dict[str, Union[bool, str, float]]] = None
+     storage: Optional[Dict[str, str]] = None
+     model: Optional[Dict[str, Union[bool, str]]] = None
+
+ class ValidationResponse(BaseModel):
+     config_validation: Dict[str, bool]
+     model_validation: Dict[str, bool]
+     folder_validation: Dict[str, bool]
+     overall_status: str
+     issues: List[str]
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ annotated-types==0.7.0
+ anyio==4.8.0
+ certifi==2024.12.14
+ click==8.1.8
+ fastapi==0.115.6
+ h11==0.14.0
+ httpcore==1.0.7
+ httptools==0.6.4
+ httpx==0.28.1
+ idna==3.10
+ litserve==0.2.5
+ pydantic==2.10.4
+ pydantic_core==2.27.2
+ python-dotenv==1.0.1
+ PyYAML==6.0.2
+ sniffio==1.3.1
+ starlette==0.41.3
+ typing_extensions==4.12.2
+ uvicorn==0.34.0
+ uvloop==0.21.0
+ watchfiles==1.0.3
+ websockets==14.1
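
The dependencies are fully pinned, so a clean environment should reproduce the setup: create a virtual environment (e.g. `python -m venv venv`, which the .gitignore above already excludes) and run `pip install -r requirements.txt`.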