File size: 6,578 Bytes
3943768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import asyncio
import os
import pickle
import sys
import json
import inspect
import threading
import traceback
import uuid
from traceback import print_exception

from pydantic import BaseModel

from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi import Depends
from fastapi.responses import JSONResponse, Response
from fastapi_utils.tasks import repeat_every
from starlette.responses import PlainTextResponse

# Ensure required directories are in sys.path so project-local imports resolve
# (gen, gpt_langchain, vision.* live under the project root / src).
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.append(project_root)
# BUG FIX: the original guard was `os.path.dirname('src') not in sys.path`,
# but os.path.dirname('src') is '' (no directory component), so it checked
# whether the empty string was on sys.path and could append 'src' repeatedly
# or never guard at all.  Check for the actual entry instead.
if 'src' not in sys.path:
    sys.path.append('src')


# similar to openai_server/server.py
def verify_api_key(authorization: str = Header(None)) -> None:
    """Validate the Authorization header against H2OGPT_OPENAI_API_KEY.

    The sentinel value 'EMPTY' (also the default) disables authentication,
    because an empty-string key cannot be handled.  Otherwise the header must
    be exactly "Bearer <key>" or a 401 is raised.
    """
    expected = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if expected == 'EMPTY':
        # auth disabled
        return
    # a missing header (None) also fails the equality check below
    if expected and authorization != f"Bearer {expected}":
        raise HTTPException(status_code=401, detail="Unauthorized")


app = FastAPI()
# Shared dependency list: routes declaring dependencies=check_key run
# verify_api_key before the handler.
check_key = [Depends(verify_api_key)]
# Fully-open CORS: any origin/method/header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


class InvalidRequestError(Exception):
    """Wrapper raised-for-display by the global exception handler (400 responses)."""


class FunctionRequest(BaseModel):
    """Request payload for the /execute_function/ endpoint."""
    function_name: str  # key into the FUNCTIONS whitelist in execute_function
    args: tuple  # positional arguments forwarded to the target function
    kwargs: dict  # keyword arguments forwarded to the target function
    use_disk: bool = False  # write result to shared disk as JSON, return the path
    use_pickle: bool = False  # write result as pickle instead of JSON (takes precedence)


@app.get("/health")
async def health() -> Response:
    """Health check: always replies 200 with an empty body (no auth required)."""
    return Response(status_code=200)


@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Catch-all handler: log the full traceback server-side, then reply 400
    with the exception message wrapped as an InvalidRequestError."""
    print_exception(exc)
    return PlainTextResponse(str(InvalidRequestError(str(exc))), status_code=400)


@app.options("/", dependencies=check_key)
async def options_route():
    # Preflight/OPTIONS responder; still gated by the API-key dependency.
    return JSONResponse(content="OK")


# Server-wide keyword arguments produced by gen.main(); filled once at startup
# by initialize_gen_kwargs() and read by execute_function().
gen_kwargs = {}
# Guards the lazy, run-once population of gen_kwargs.
gen_kwargs_lock = threading.Lock()


def initialize_gen_kwargs():
    """Populate the global gen_kwargs exactly once from the required
    H2OGPT_MAIN_KWARGS environment variable (JSON), forcing pure
    "function server" mode before calling gen.main()."""
    global gen_kwargs
    with gen_kwargs_lock:  # not strictly required if in global scope
        if gen_kwargs:
            return
        main_kwargs = json.loads(os.environ['H2OGPT_MAIN_KWARGS'])  # required
        main_kwargs.update({
            # don't double up LLMs, in pure "document ingest" mode
            'model_lock': [],
            'base_model': '',
            'inference_server': '',
            # only for chat part, not used here
            'enable_image': False,
            'visible_image_models': [],
            'image_gpu_ids': None,
            'enable_tts': False,
            'enable_stt': False,
            # function server mode only
            'gradio': False,
            'eval': False,
            'cli': False,
            'function': True,
            # don't double this
            'openai_server': False,
        })
        # FIXME: Deal with GPU IDs for each caption/ASR/DocTR model, use MIG, etc.
        from gen import main as gen_main
        gen_kwargs = gen_main(**main_kwargs)


# Call the initialization function at startup, but not during import
# (allows importing this module for tooling without a configured environment).
if 'H2OGPT_MAIN_KWARGS' in os.environ:
    initialize_gen_kwargs()
else:
    print("H2OGPT_MAIN_KWARGS not found in os.environ")


@app.post("/execute_function/", dependencies=check_key)
def execute_function(request: FunctionRequest):
    """Execute a whitelisted server-side function and return its result.

    The function name must be one of the FUNCTIONS keys (a whitelist — never
    resolve arbitrary names from the request).  Missing keyword arguments are
    filled in from the server-wide gen_kwargs when the function's signature
    accepts them.  Results are returned inline, or written to shared disk as
    JSON/pickle when use_disk/use_pickle is set (then the absolute file path
    is returned instead).

    Raises HTTPException(500) with the full traceback text on any failure.
    """
    # Mapping of function names to function objects; imports are deferred so
    # module import stays cheap and env-independent.
    from gpt_langchain import path_to_docs
    from vision.utils_vision import process_file_list
    FUNCTIONS = {
        'path_to_docs': path_to_docs,
        'process_file_list': process_file_list,
    }
    try:
        # Fetch the function from the function map
        func = FUNCTIONS.get(request.function_name)
        if not func:
            raise ValueError("Function not found")

        # use gen_kwargs for any accepted parameter the caller did not supply
        func_names = list(inspect.signature(func).parameters)
        func_kwargs = {k: v for k, v in gen_kwargs.items() if k in func_names and k not in request.kwargs}

        # Call the function with args and kwargs
        result = func(*request.args, **request.kwargs, **func_kwargs)

        if request.use_disk or request.use_pickle:
            # Save the result to a file on the shared disk
            base_path = 'function_results'
            # BUG FIX: exist_ok=True replaces the isdir()-then-makedirs()
            # check, which raced when multiple workers created the directory
            # concurrently (FileExistsError).
            os.makedirs(base_path, exist_ok=True)
            file_path = os.path.join(base_path, str(uuid.uuid4()))
            if request.use_pickle:
                file_path += '.pkl'
                with open(file_path, "wb") as f:
                    pickle.dump(result, f)
            else:
                file_path += '.json'
                with open(file_path, "w") as f:
                    json.dump(result, f)
            return {"status": "success", "file_path": os.path.abspath(file_path)}
        else:
            # Return the result directly
            return {"status": "success", "result": result}
    except Exception as e:
        # surface the full traceback to the (trusted, authenticated) caller
        traceback_str = ''.join(traceback.format_exception(e))
        raise HTTPException(status_code=500, detail=traceback_str)
    finally:
        # health probe on every request; may hard-exit the worker on failure
        do_check(in_finally=True)


def do_check(in_finally=False):
    """Run the health probe and, on failure, hard-exit the worker.

    os._exit(1) deliberately skips cleanup to avoid shutdown races; it is
    only used when running under gunicorn with multiple workers (so a
    replacement worker gets spawned).
    """
    if check_some_conditions():
        return
    print("Health check failed! Terminating without cleanup (to avoid races) %s..." % in_finally)
    if os.getenv('multiple_workers_gunicorn'):
        os._exit(1)


# Feature flag for the periodic background health check.
state_checks = True
if state_checks:
    @app.on_event("startup")
    async def startup_event(verbose=True):
        # Launch the health-check loop once the event loop is running.
        asyncio.create_task(periodic_health_check(verbose=verbose))


    async def periodic_health_check(verbose=False):
        # Runs forever: every 2 minutes trigger do_check(), which may
        # hard-exit the worker process when the health probe fails.
        while True:
            if verbose:
                print("Checking health...")
            await asyncio.sleep(120)  # Wait for 2 minutes between checks
            do_check(in_finally=False)


    def check_some_conditions():
        """Health probe: confirm the standard streams are still writable.

        Returns False when flushing raises (e.g. 'I/O operation on closed
        file' caused by some non-python package closing our streams), which
        signals the caller to terminate the worker.
        FIXME: replace with real health-check logic.
        """
        try:
            for stream in (sys.stdout, sys.stderr):
                stream.flush()
        except BaseException:
            # catch even SystemExit/KeyboardInterrupt-level failures here
            traceback.print_exc()
            return False
        return True