"""FastAPI service exposing the GOT-OCR2.0 model over HTTP.

Endpoints:
    POST /ocr -- run OCR on an uploaded image in the requested ``got_mode``.
    GET  /    -- describe the service, loaded model, device and supported modes.
"""

import base64
import contextlib
import logging
import os
import tempfile
import threading
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware (every origin, method and header is allowed).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the model once at import time; all requests share this instance.
model_name = "Mageia/GOT-OCR2_0"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
model = model.eval().to(device)

# Inference runs in worker threads so the event loop stays responsive. This lock
# keeps model calls one-at-a-time, the same serialization the service had when
# inference ran directly on the event loop.
_model_lock = threading.Lock()


def _client_host(request):
    """Return the client's host, or "Unknown" when the server does not report one."""
    return request.client.host if request.client is not None else "Unknown"


def _ocr_sync(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run blocking GOT-OCR inference on the image stored at ``image_path``.

    Args:
        image_path: Path of the image file on local disk.
        got_mode: Mode string; checked for the substrings "plain", "format" and
            "multi-crop".
        ocr_color: Passed through to ``model.chat`` (not used by multi-crop).
        ocr_box: Passed through to ``model.chat`` (not used by multi-crop).

    Returns:
        For "plain" modes, whatever ``model.chat``/``model.chat_crop`` returns.
        For "format" modes, ``{"html_content": <base64 of the rendered HTML>}``.
        Otherwise ``{"error": ...}``.

    Raises:
        Any exception raised by the model or by file I/O propagates to the caller.
    """
    if "plain" in got_mode:
        with _model_lock:
            if "multi-crop" in got_mode:
                return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)

    if "format" in got_mode:
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"
        try:
            with _model_lock:
                if "multi-crop" in got_mode:
                    model.chat_crop(
                        tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path
                    )
                else:
                    model.chat(
                        tokenizer,
                        image_path,
                        ocr_type="format",
                        ocr_box=ocr_box,
                        ocr_color=ocr_color,
                        render=True,
                        save_render_file=result_path,
                    )
            if not os.path.exists(result_path):
                # Report the real problem instead of falling through to "unknown mode".
                return {"error": "未生成渲染结果文件"}
            with open(result_path, "r", encoding="utf-8") as f:
                html_content = f.read()
        finally:
            # The rendered HTML is only needed to build the response; don't leave it on disk.
            with contextlib.suppress(FileNotFoundError):
                os.remove(result_path)
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        return {"html_content": encoded_html}

    return {"error": "未知的OCR模式"}


# OCR processing function
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run OCR on ``image_path`` in a worker thread, without blocking the event loop.

    Returns the result of ``_ocr_sync``. Any exception is logged with its
    traceback and reported as ``{"error": str(exc)}`` instead of being raised.
    """
    try:
        return await run_in_threadpool(_ocr_sync, image_path, got_mode, ocr_color, ocr_box)
    except Exception as e:
        logger.exception("OCR处理失败")
        return {"error": str(e)}


@app.post("/ocr")
async def ocr_api(
    request: Request,
    image: UploadFile = File(...),
    got_mode: str = Form(...),
    ocr_color: str = Form(""),
    ocr_box: str = Form(""),
):
    """Log the request, save the upload to a temp file, run OCR and return the result.

    The temp file is always deleted, even when writing it or running OCR fails.
    """
    # Log request information
    client_host = _client_host(request)
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
    时间: {current_time}
    IP地址: {client_host}
    User-Agent: {user_agent}
    图片名称: {image.filename}
    OCR模式: {got_mode}
    OCR颜色: {ocr_color}
    OCR边界框: {ocr_box}
    """
    logger.info(log_message)

    # Save the uploaded image under a unique, server-chosen name. The client's
    # filename is untrusted (it may contain "../"), and concurrent uploads with
    # the same name would otherwise overwrite each other; keep only its extension.
    contents = await image.read()
    suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]
    fd, image_path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(contents)
        # Run OCR
        result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)
    finally:
        # Delete the temporary file
        with contextlib.suppress(FileNotFoundError):
            os.remove(image_path)

    # Log the processing result
    logger.info(f"OCR处理结果: {result}")
    return result


@app.get("/")
async def read_root(request: Request):
    """Log the visit and return service metadata plus the list of supported OCR modes."""
    client_host = _client_host(request)
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
    时间: {current_time}
    IP地址: {client_host}
    User-Agent: {user_agent}
    访问: 根路径
    """
    logger.info(log_message)
    return {
        "message": "欢迎使用OCR API",
        "user_agent": user_agent,
        "model": model_name,
        "device": device,
        "ocr_mode": [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    }