import asyncio
import os
import pickle
import sys
import json
import inspect
import threading
import traceback
import uuid
from traceback import print_exception

from pydantic import BaseModel
from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi import Depends
from fastapi.responses import JSONResponse, Response
from fastapi_utils.tasks import repeat_every
from starlette.responses import PlainTextResponse

# Ensure required directories are in sys.path
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.append(project_root)
if os.path.dirname('src') not in sys.path:
    sys.path.append('src')


# similar to openai_server/server.py
def verify_api_key(authorization: str = Header(None)) -> None:
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    # print("server_api_key: %s %s" % (server_api_key, authorization))
    if server_api_key == 'EMPTY':
        # dummy case since '' cannot be handled
        return
    if server_api_key and (authorization is None or authorization != f"Bearer {server_api_key}"):
        raise HTTPException(status_code=401, detail="Unauthorized")


app = FastAPI()
check_key = [Depends(verify_api_key)]

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class InvalidRequestError(Exception):
    pass


class FunctionRequest(BaseModel):
    function_name: str
    args: tuple
    kwargs: dict
    use_disk: bool = False
    use_pickle: bool = False


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    print_exception(exc)
    exc2 = InvalidRequestError(str(exc))
    return PlainTextResponse(str(exc2), status_code=400)


@app.options("/", dependencies=check_key)
async def options_route():
    return JSONResponse(content="OK")


gen_kwargs = {}
gen_kwargs_lock = threading.Lock()


def initialize_gen_kwargs():
    global gen_kwargs
    with gen_kwargs_lock:  # not strictly required if in global scope
        if not gen_kwargs:
            main_kwargs = json.loads(os.environ['H2OGPT_MAIN_KWARGS'])  # required
            # don't double up LLMs, in pure "document ingest" mode
            main_kwargs['model_lock'] = []
            main_kwargs['base_model'] = ''
            main_kwargs['inference_server'] = ''
            # only for chat part, not used here
            main_kwargs['enable_image'] = False
            main_kwargs['visible_image_models'] = []
            main_kwargs['image_gpu_ids'] = None
            main_kwargs['enable_tts'] = False
            main_kwargs['enable_stt'] = False
            # function server mode only
            main_kwargs['gradio'] = False
            main_kwargs['eval'] = False
            main_kwargs['cli'] = False
            main_kwargs['function'] = True
            # don't double this
            main_kwargs['openai_server'] = False
            # FIXME: Deal with GPU IDs for each caption/ASR/DocTR model, use MIG, etc.
            from gen import main as gen_main
            gen_kwargs = gen_main(**main_kwargs)


# Call the initialization function at startup, but not during import
if 'H2OGPT_MAIN_KWARGS' in os.environ:
    initialize_gen_kwargs()
else:
    print("H2OGPT_MAIN_KWARGS not found in os.environ")


@app.post("/execute_function/", dependencies=check_key)
def execute_function(request: FunctionRequest):
    # Mapping of function names to function objects
    from gpt_langchain import path_to_docs
    from vision.utils_vision import process_file_list
    FUNCTIONS = {
        'path_to_docs': path_to_docs,
        'process_file_list': process_file_list,
    }
    try:
        # Fetch the function from the function map
        func = FUNCTIONS.get(request.function_name)
        if not func:
            raise ValueError("Function not found")
        # use gen_kwargs if needed
        func_names = list(inspect.signature(func).parameters)
        func_kwargs = {k: v for k, v in gen_kwargs.items() if k in func_names and k not in request.kwargs}
        # Call the function with args and kwargs
        result = func(*request.args, **request.kwargs, **func_kwargs)

        if request.use_disk or request.use_pickle:
            # Save the result to a file on the shared disk
            base_path = 'function_results'
            if not os.path.isdir(base_path):
                os.makedirs(base_path)
            file_path = os.path.join(base_path, str(uuid.uuid4()))
            if request.use_pickle:
                file_path += '.pkl'
                with open(file_path, "wb") as f:
                    pickle.dump(result, f)
            else:
                file_path += '.json'
                with open(file_path, "w") as f:
                    json.dump(result, f)
            return {"status": "success", "file_path": os.path.abspath(file_path)}
        else:
            # Return the result directly
            return {"status": "success", "result": result}
    except Exception as e:
        traceback_str = ''.join(traceback.format_exception(e))
        raise HTTPException(status_code=500, detail=traceback_str)
    finally:
        do_check(in_finally=True)


def do_check(in_finally=False):
    health_result = check_some_conditions()
    if not health_result:
        print("Health check failed! Terminating without cleanup (to avoid races) %s..." % in_finally)
        if os.getenv('multiple_workers_gunicorn'):
            os._exit(1)


state_checks = True
if state_checks:
    @app.on_event("startup")
    async def startup_event(verbose=True):
        asyncio.create_task(periodic_health_check(verbose=verbose))


async def periodic_health_check(verbose=False):
    while True:
        if verbose:
            print("Checking health...")
        await asyncio.sleep(120)  # Wait for 2 minutes between checks
        do_check(in_finally=False)


def check_some_conditions():
    # Replace with actual health check logic
    # Return False if something is wrong
    try:
        sys.stdout.flush()
        sys.stderr.flush()
        return True
    except BaseException:
        # to catch case when hit I/O operation on closed file, from some unknown non-python package
        traceback.print_exc()
        return False
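

# ---------------------------------------------------------------------------
# Example client usage (illustrative sketch only; nothing below is called by
# the server itself). It assumes the server is reachable at
# http://localhost:5002, that the `requests` package is installed, and that
# H2OGPT_OPENAI_API_KEY is unset so the default 'EMPTY' key disables auth;
# the host, port, bearer token, and the arguments passed to `path_to_docs`
# are placeholders to adapt to your deployment.
def _example_execute_function_client():
    import requests

    payload = {
        "function_name": "path_to_docs",
        "args": [["./docs"]],       # hypothetical input; match path_to_docs' real signature
        "kwargs": {},
        "use_disk": False,
        "use_pickle": False,
    }
    headers = {"Authorization": "Bearer EMPTY"}  # only needed when an API key is configured
    resp = requests.post("http://localhost:5002/execute_function/", json=payload, headers=headers)
    resp.raise_for_status()
    # Either {"status": "success", "result": ...} or {"status": "success", "file_path": ...}
    return resp.json()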
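

# When `use_disk` or `use_pickle` is set, execute_function returns a file path
# under `function_results/` instead of the result itself. The sketch below
# shows how a caller sharing the same filesystem might load that file back;
# the helper name and the shared-mount assumption are illustrative, not part
# of this module's API.
def _example_load_result(file_path):
    """Load a result produced with use_disk/use_pickle (illustrative sketch only)."""
    if file_path.endswith('.pkl'):
        with open(file_path, "rb") as f:
            return pickle.load(f)
    with open(file_path, "r") as f:
        return json.load(f)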