File size: 4,681 Bytes
47031d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import APIRouter, HTTPException
from typing import Optional
from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse
)
import logging

logger = logging.getLogger(__name__)
router = APIRouter()
# Shared InferenceApi instance used by every endpoint; None until init_router() runs.
api: Optional[InferenceApi] = None

def init_router(config: dict):
    """Create the module-level InferenceApi instance used by all endpoints.

    Args:
        config: Configuration mapping passed unchanged to ``InferenceApi``.

    Note:
        Calling this again simply replaces the current instance; the previous
        one is not closed here.
    """
    global api
    instance = InferenceApi(config)
    api = instance
    logger.info("Router initialized with Inference API instance")

@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a complete (non-streaming) text response for a prompt.

    Args:
        request: Prompt, optional system message and ``max_new_tokens`` limit,
            forwarded unchanged to ``InferenceApi.generate_response``.

    Returns:
        ``{"generated_text": <value returned by generate_response>}``.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error from the inference call. An
            HTTPException raised downstream is propagated unchanged.
    """
    logger.info("Received generation request for prompt: %s...", request.prompt[:50])
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error in generate_text endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Start a streaming text generation for a prompt.

    The object returned by ``InferenceApi.generate_stream`` is returned to
    FastAPI as-is.

    Args:
        request: Prompt, optional system message and ``max_new_tokens`` limit,
            forwarded unchanged to ``InferenceApi.generate_stream``.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any error raised by the ``generate_stream`` call itself.
            An HTTPException raised downstream is propagated unchanged.
    """
    logger.info(
        "Received streaming generation request for prompt: %s...", request.prompt[:50]
    )
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        # NOTE(review): generate_stream is not awaited and its result is not
        # wrapped. This assumes it returns a ready Response (e.g. a
        # StreamingResponse). A bare async generator would be JSON-encoded
        # by FastAPI rather than streamed — confirm against InferenceApi.
        # If the result is consumed lazily, errors raised during iteration
        # happen after this function returns and are not caught below.
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error in generate_stream endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector for the given text.

    Args:
        request: Carries the ``text`` forwarded to ``InferenceApi.generate_embedding``.

    Returns:
        An ``EmbeddingResponse`` holding the embedding and its length as
        ``dimension``.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error (including building the response). An
            HTTPException raised downstream is propagated unchanged.
    """
    logger.info("Received embedding request for text: %s...", request.text[:50])
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        embedding = await api.generate_embedding(request.text)
        dimension = len(embedding)
        logger.info("Successfully generated embedding of dimension %d", dimension)
        return EmbeddingResponse(embedding=embedding, dimension=dimension)
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error in generate_embedding endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.get("/system/status",
            response_model=SystemStatusResponse,
            summary="Check System Status",
            description="Returns comprehensive system status including CPU, Memory, GPU, Storage, and Model information")
async def check_system():
    """Return the result of ``InferenceApi.check_system_status``.

    FastAPI validates the returned value against ``SystemStatusResponse``.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error. An HTTPException raised downstream is
            propagated unchanged.
    """
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.check_system_status()
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error checking system status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.get("/system/validate",
            response_model=ValidationResponse,
            summary="Validate System Configuration",
            description="Validates system configuration, folders, and model setup")
async def validate_system():
    """Return the result of ``InferenceApi.validate_system``.

    FastAPI validates the returned value against ``ValidationResponse``.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error. An HTTPException raised downstream is
            propagated unchanged.
    """
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.validate_system()
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error validating system: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.post("/model/initialize",
             summary="Initialize default or specified model",
             description="Initialize model for use. Uses default model from config if none specified.")
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model through ``InferenceApi.initialize_model``.

    Args:
        model_name: Model to load; ``None`` is passed through unchanged
            (per the route description, the config default is used then).

    Returns:
        Whatever ``InferenceApi.initialize_model`` returns.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error. An HTTPException raised downstream is
            propagated unchanged.
    """
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_model(model_name)
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error initializing model: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.post("/model/initialize/embedding",
             summary="Initialize embedding model",
             description="Initialize a separate model specifically for generating embeddings")
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize the embedding model through ``InferenceApi.initialize_embedding_model``.

    Args:
        model_name: Model to load for embeddings; ``None`` is passed through
            unchanged to ``InferenceApi``.

    Returns:
        Whatever ``InferenceApi.initialize_embedding_model`` returns.

    Raises:
        HTTPException: 503 if ``init_router`` has not been called yet; 500
            wrapping any other error. An HTTPException raised downstream is
            propagated unchanged.
    """
    if api is None:
        # Without this guard the call below fails with an opaque AttributeError.
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_embedding_model(model_name)
    except HTTPException:
        # Keep the original status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception("Error initializing embedding model: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

# NOTE(review): on_event is deprecated in recent FastAPI releases in favour of
# lifespan handlers; consider migrating when the app setup is refactored.
@router.on_event("shutdown")
async def shutdown_event():
    """Close the shared InferenceApi instance on application shutdown.

    The module-level reference is cleared before closing, so the closed
    instance is never handed to a later request.
    """
    global api
    # Explicit None check: truthiness could skip close() if the instance
    # happened to evaluate as falsy.
    if api is None:
        return
    instance, api = api, None
    await instance.close()