init
Browse files- .gitattributes +1 -0
- .gitignore +11 -0
- Dockerfile +17 -0
- README.md +8 -0
- main.py +3 -0
- ocr/__init__.py +46 -0
- ocr/api/message/__init__.py +7 -0
- ocr/api/message/ai/openai_request.py +17 -0
- ocr/api/message/ai/prompts.py +44 -0
- ocr/api/message/db_requests.py +26 -0
- ocr/api/message/dto.py +7 -0
- ocr/api/message/model.py +12 -0
- ocr/api/message/schemas.py +22 -0
- ocr/api/message/views.py +41 -0
- ocr/core/config.py +38 -0
- ocr/core/database.py +101 -0
- ocr/core/wrappers.py +87 -0
- requirements.txt +33 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.index filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
env/
|
3 |
+
venv/
|
4 |
+
.venv/
|
5 |
+
.idea/
|
6 |
+
*.log
|
7 |
+
*.egg-info/
|
8 |
+
pip-wheel-EntityData/
|
9 |
+
.env
|
10 |
+
.DS_Store
|
11 |
+
static/
|
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.12.7
|
2 |
+
|
3 |
+
RUN useradd -m -u 1000 user
|
4 |
+
USER user
|
5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
6 |
+
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
11 |
+
|
12 |
+
USER root
|
13 |
+
RUN apt-get update && apt-get install -y poppler-utils
|
14 |
+
USER user
|
15 |
+
|
16 |
+
COPY --chown=user . /app
|
17 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Ocr 2
|
3 |
+
emoji: 😻
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
---
|
main.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from ocr import create_app
|
2 |
+
|
3 |
+
app = create_app()
|
ocr/__init__.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from fastapi import FastAPI
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
6 |
+
from starlette.staticfiles import StaticFiles
|
7 |
+
|
8 |
+
from ocr.core.config import settings
|
9 |
+
from ocr.core.wrappers import OcrResponseWrapper, ErrorOcrResponse
|
10 |
+
|
11 |
+
|
12 |
+
def create_app() -> FastAPI:
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
from ocr.api.message import report_router
|
16 |
+
app.include_router(report_router, tags=['message'])
|
17 |
+
|
18 |
+
app.add_middleware(
|
19 |
+
CORSMiddleware,
|
20 |
+
allow_origins=["*"],
|
21 |
+
allow_methods=["*"],
|
22 |
+
allow_headers=["*"],
|
23 |
+
)
|
24 |
+
|
25 |
+
static_directory = os.path.join(settings.BASE_DIR, 'static')
|
26 |
+
if not os.path.exists(static_directory):
|
27 |
+
os.makedirs(static_directory)
|
28 |
+
|
29 |
+
app.mount(
|
30 |
+
'/static',
|
31 |
+
StaticFiles(directory='static'),
|
32 |
+
)
|
33 |
+
|
34 |
+
@app.exception_handler(StarletteHTTPException)
|
35 |
+
async def http_exception_handler(_, exc):
|
36 |
+
return OcrResponseWrapper(
|
37 |
+
data=None,
|
38 |
+
successful=False,
|
39 |
+
error=ErrorOcrResponse(message=str(exc.detail))
|
40 |
+
).response(exc.status_code)
|
41 |
+
|
42 |
+
@app.get("/")
|
43 |
+
async def read_root():
|
44 |
+
return {"message": "Hello world!"}
|
45 |
+
|
46 |
+
return app
|
ocr/api/message/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.routing import APIRouter
|
2 |
+
|
3 |
+
report_router = APIRouter(
|
4 |
+
prefix="/api/report", tags=["message"]
|
5 |
+
)
|
6 |
+
|
7 |
+
from . import views
|
ocr/api/message/ai/openai_request.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ocr.api.message.ai.prompts import OCRPrompts
|
2 |
+
from ocr.core.wrappers import openai_wrapper
|
3 |
+
|
4 |
+
|
5 |
+
@openai_wrapper(model='gpt-4o-mini')
|
6 |
+
async def generate_report(request_content: list[dict]):
|
7 |
+
messages = [
|
8 |
+
{
|
9 |
+
"role": "system",
|
10 |
+
"content": OCRPrompts.generate_general_answer
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"role": "user",
|
14 |
+
"content": request_content
|
15 |
+
}
|
16 |
+
]
|
17 |
+
return messages
|
ocr/api/message/ai/prompts.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class OCRPrompts:
|
2 |
+
generate_general_answer = """## Task
|
3 |
+
|
4 |
+
You must analyze the text extracted from medical document and generate a comprehensive report in **Markdown2** format. Ensure that every detail provided in the document is included, and do not omit or modify any information. Your output must strictly follow the required format.
|
5 |
+
|
6 |
+
## Report Structure
|
7 |
+
|
8 |
+
The report should be structured as follows, with each section containing only relevant information from the document:
|
9 |
+
|
10 |
+
```markdown
|
11 |
+
## Patient Information
|
12 |
+
|
13 |
+
- Name: [Patient Name]
|
14 |
+
- Age: [Patient Age]
|
15 |
+
- Date of Scan: [Date]
|
16 |
+
- Indication: [Reason for the CT scan]
|
17 |
+
|
18 |
+
## Findings
|
19 |
+
|
20 |
+
**Primary findings**:
|
21 |
+
[Describe significant abnormalities or findings relevant to the indication]
|
22 |
+
|
23 |
+
** Secondary findings**:
|
24 |
+
[List incidental findings, e.g., "Mild hepatic steatosis noted."]
|
25 |
+
**No abnormalities**:
|
26 |
+
[Mention organs or systems without abnormalities, e.g., "No evidence of lymphadenopathy or pleural effusion."]
|
27 |
+
|
28 |
+
## Impression
|
29 |
+
|
30 |
+
[Summarize the findings concisely, e.g., "Findings suggest a primary lung tumor. Biopsy recommended for further evaluation."]
|
31 |
+
|
32 |
+
## Recommendations
|
33 |
+
|
34 |
+
[Include next steps or further tests, e.g., "PET scan and consultation with oncology recommended."]
|
35 |
+
```
|
36 |
+
|
37 |
+
[INST]
|
38 |
+
|
39 |
+
## Instructions
|
40 |
+
|
41 |
+
- **Do not invent or infer any information.** Only use data provided in the user request.
|
42 |
+
- Ensure that the format is followed strictly, and the output is complete without any deviations.
|
43 |
+
|
44 |
+
[/INST]"""
|
ocr/api/message/db_requests.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from ocr.api.message.model import MessageModel
|
4 |
+
from ocr.core.config import settings
|
5 |
+
|
6 |
+
|
7 |
+
async def get_all_chat_messages_obj(page_size: int, page_index: int) -> tuple[list[MessageModel], int]:
|
8 |
+
skip = page_size * page_index
|
9 |
+
objects, total_count = await asyncio.gather(
|
10 |
+
settings.DB_CLIENT.messages
|
11 |
+
.find()
|
12 |
+
.skip(skip)
|
13 |
+
.limit(page_size)
|
14 |
+
.to_list(length=page_size),
|
15 |
+
settings.DB_CLIENT.messages.count_documents({})
|
16 |
+
)
|
17 |
+
return objects, total_count
|
18 |
+
|
19 |
+
|
20 |
+
async def save_report_obj(report: str, filename: str) -> MessageModel:
|
21 |
+
message = MessageModel(
|
22 |
+
text=report,
|
23 |
+
filename=filename,
|
24 |
+
)
|
25 |
+
await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
|
26 |
+
return message
|
ocr/api/message/dto.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class Paging(BaseModel):
|
5 |
+
pageSize: int
|
6 |
+
pageIndex: int
|
7 |
+
totalCount: int
|
ocr/api/message/model.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from pydantic import Field
|
4 |
+
|
5 |
+
from ocr.core.database import MongoBaseModel
|
6 |
+
|
7 |
+
|
8 |
+
class MessageModel(MongoBaseModel):
|
9 |
+
text: str
|
10 |
+
filename: str
|
11 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
12 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
ocr/api/message/schemas.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
from ocr.api.message.dto import Paging
|
4 |
+
from ocr.api.message.model import MessageModel
|
5 |
+
from ocr.core.wrappers import OcrResponseWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class CreateMessageRequest(BaseModel):
|
9 |
+
text: str
|
10 |
+
|
11 |
+
|
12 |
+
class MessageWrapper(OcrResponseWrapper[MessageModel]):
|
13 |
+
pass
|
14 |
+
|
15 |
+
|
16 |
+
class AllMessageResponse(BaseModel):
|
17 |
+
paging: Paging
|
18 |
+
data: list[MessageModel]
|
19 |
+
|
20 |
+
|
21 |
+
class AllMessageWrapper(OcrResponseWrapper[AllMessageResponse]):
|
22 |
+
pass
|
ocr/api/message/views.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from fastapi import Query, UploadFile, File
|
4 |
+
|
5 |
+
from ocr.api.message import report_router
|
6 |
+
from ocr.api.message.ai.openai_request import generate_report
|
7 |
+
from ocr.api.message.db_requests import get_all_chat_messages_obj, save_report_obj
|
8 |
+
from ocr.api.message.dto import Paging
|
9 |
+
from ocr.api.message.model import MessageModel
|
10 |
+
from ocr.api.message.schemas import (AllMessageWrapper,
|
11 |
+
AllMessageResponse)
|
12 |
+
from ocr.api.message.utils import divide_images, prepare_request_content, clean_response
|
13 |
+
from ocr.core.wrappers import OcrResponseWrapper
|
14 |
+
|
15 |
+
|
16 |
+
@report_router.get('/all')
|
17 |
+
async def get_all_chat_messages(
|
18 |
+
pageSize: Optional[int] = Query(10, description="Number of countries to return per page"),
|
19 |
+
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
20 |
+
) -> AllMessageWrapper:
|
21 |
+
messages, _ = await get_all_chat_messages_obj(pageSize, pageIndex)
|
22 |
+
response = AllMessageResponse(
|
23 |
+
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
24 |
+
data=messages
|
25 |
+
)
|
26 |
+
return AllMessageWrapper(data=response)
|
27 |
+
|
28 |
+
|
29 |
+
@report_router.post('/generate')
|
30 |
+
async def create_message(
|
31 |
+
file: UploadFile = File(...),
|
32 |
+
) -> OcrResponseWrapper[MessageModel]:
|
33 |
+
try:
|
34 |
+
contents = await file.read()
|
35 |
+
images = divide_images(contents)
|
36 |
+
content = prepare_request_content(images)
|
37 |
+
report = await generate_report(content)
|
38 |
+
response = await save_report_obj(clean_response(report), file.filename)
|
39 |
+
return OcrResponseWrapper(data=response)
|
40 |
+
finally:
|
41 |
+
await file.close()
|
ocr/core/config.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
+
import motor.motor_asyncio
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from openai import AsyncClient
|
8 |
+
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
class BaseConfig:
|
12 |
+
BASE_DIR: pathlib.Path = pathlib.Path(__file__).parent.parent.parent
|
13 |
+
SECRET_KEY = os.getenv('SECRET')
|
14 |
+
OPENAI_CLIENT = AsyncClient(api_key=os.getenv('OPENAI_API_KEY'))
|
15 |
+
DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.getenv("MONGO_DB_URL")).Ocr
|
16 |
+
|
17 |
+
class DevelopmentConfig(BaseConfig):
|
18 |
+
Issuer = "http://localhost:8000"
|
19 |
+
Audience = "http://localhost:3000"
|
20 |
+
|
21 |
+
|
22 |
+
class ProductionConfig(BaseConfig):
|
23 |
+
Issuer = ""
|
24 |
+
Audience = ""
|
25 |
+
|
26 |
+
|
27 |
+
@lru_cache()
|
28 |
+
def get_settings() -> DevelopmentConfig | ProductionConfig:
|
29 |
+
config_cls_dict = {
|
30 |
+
'development': DevelopmentConfig,
|
31 |
+
'production': ProductionConfig,
|
32 |
+
}
|
33 |
+
config_name = os.getenv('FASTAPI_CONFIG', default='development')
|
34 |
+
config_cls = config_cls_dict[config_name]
|
35 |
+
return config_cls()
|
36 |
+
|
37 |
+
|
38 |
+
settings = get_settings()
|
ocr/core/database.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from enum import Enum
|
3 |
+
from typing import Dict, Any, Type
|
4 |
+
|
5 |
+
from bson import ObjectId
|
6 |
+
from pydantic import GetCoreSchemaHandler, BaseModel, Field, AnyUrl
|
7 |
+
from pydantic.json_schema import JsonSchemaValue
|
8 |
+
from pydantic_core import core_schema
|
9 |
+
|
10 |
+
|
11 |
+
class PyObjectId:
|
12 |
+
@classmethod
|
13 |
+
def __get_pydantic_core_schema__(
|
14 |
+
cls, source: type, handler: GetCoreSchemaHandler
|
15 |
+
) -> core_schema.CoreSchema:
|
16 |
+
return core_schema.with_info_after_validator_function(
|
17 |
+
cls.validate, core_schema.str_schema()
|
18 |
+
)
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def __get_pydantic_json_schema__(
|
22 |
+
cls, schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
|
23 |
+
) -> JsonSchemaValue:
|
24 |
+
return {"type": "string"}
|
25 |
+
|
26 |
+
@classmethod
|
27 |
+
def validate(cls, value: str) -> ObjectId:
|
28 |
+
if not ObjectId.is_valid(value):
|
29 |
+
raise ValueError(f"Invalid ObjectId: {value}")
|
30 |
+
return ObjectId(value)
|
31 |
+
|
32 |
+
def __getattr__(self, item):
|
33 |
+
return getattr(self.__dict__['value'], item)
|
34 |
+
|
35 |
+
def __init__(self, value: str = None):
|
36 |
+
if value is None:
|
37 |
+
self.value = ObjectId()
|
38 |
+
else:
|
39 |
+
self.value = self.validate(value)
|
40 |
+
|
41 |
+
def __str__(self):
|
42 |
+
return str(self.value)
|
43 |
+
|
44 |
+
|
45 |
+
class MongoBaseModel(BaseModel):
|
46 |
+
id: str = Field(default_factory=lambda: str(PyObjectId()))
|
47 |
+
|
48 |
+
class Config:
|
49 |
+
arbitrary_types_allowed = True
|
50 |
+
|
51 |
+
def to_mongo(self) -> Dict[str, Any]:
|
52 |
+
def model_to_dict(model: BaseModel) -> Dict[str, Any]:
|
53 |
+
doc = {}
|
54 |
+
for name, value in model._iter():
|
55 |
+
key = model.__fields__[name].alias or name
|
56 |
+
|
57 |
+
if isinstance(value, BaseModel):
|
58 |
+
doc[key] = model_to_dict(value)
|
59 |
+
elif isinstance(value, list) and all(isinstance(i, BaseModel) for i in value):
|
60 |
+
doc[key] = [model_to_dict(item) for item in value]
|
61 |
+
elif value and isinstance(value, Enum):
|
62 |
+
doc[key] = value.value
|
63 |
+
elif isinstance(value, datetime):
|
64 |
+
doc[key] = value.isoformat()
|
65 |
+
elif value and isinstance(value, AnyUrl):
|
66 |
+
doc[key] = str(value)
|
67 |
+
else:
|
68 |
+
doc[key] = value
|
69 |
+
|
70 |
+
return doc
|
71 |
+
|
72 |
+
result = model_to_dict(self)
|
73 |
+
return result
|
74 |
+
|
75 |
+
@classmethod
|
76 |
+
def from_mongo(cls, data: Dict[str, Any]):
|
77 |
+
def restore_enums(inst: Any, model_cls: Type[BaseModel]) -> None:
|
78 |
+
for name, field in model_cls.__fields__.items():
|
79 |
+
value = getattr(inst, name)
|
80 |
+
if field and isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
81 |
+
setattr(inst, name, field.annotation(value))
|
82 |
+
elif isinstance(value, BaseModel):
|
83 |
+
restore_enums(value, value.__class__)
|
84 |
+
elif isinstance(value, list):
|
85 |
+
for i, item in enumerate(value):
|
86 |
+
if isinstance(item, BaseModel):
|
87 |
+
restore_enums(item, item.__class__)
|
88 |
+
elif isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
89 |
+
value[i] = field.annotation(item)
|
90 |
+
elif isinstance(value, dict):
|
91 |
+
for k, v in value.items():
|
92 |
+
if isinstance(v, BaseModel):
|
93 |
+
restore_enums(v, v.__class__)
|
94 |
+
elif isinstance(field.annotation, type) and issubclass(field.annotation, Enum):
|
95 |
+
value[k] = field.annotation(v)
|
96 |
+
|
97 |
+
if data is None:
|
98 |
+
return None
|
99 |
+
instance = cls(**data)
|
100 |
+
restore_enums(instance, instance.__class__)
|
101 |
+
return instance
|
ocr/core/wrappers.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from functools import wraps
|
3 |
+
from typing import Generic, Optional, TypeVar
|
4 |
+
|
5 |
+
import pydash
|
6 |
+
from fastapi import HTTPException
|
7 |
+
from pydantic import BaseModel
|
8 |
+
from starlette.responses import JSONResponse
|
9 |
+
|
10 |
+
from ocr.core.config import settings
|
11 |
+
|
12 |
+
T = TypeVar('T')
|
13 |
+
|
14 |
+
|
15 |
+
class ErrorOcrResponse(BaseModel):
|
16 |
+
message: str
|
17 |
+
|
18 |
+
|
19 |
+
class OcrResponseWrapper(BaseModel, Generic[T]):
|
20 |
+
data: Optional[T] = None
|
21 |
+
successful: bool = True
|
22 |
+
error: Optional[ErrorOcrResponse] = None
|
23 |
+
|
24 |
+
def response(self, status_code: int):
|
25 |
+
return JSONResponse(
|
26 |
+
status_code=status_code,
|
27 |
+
content={
|
28 |
+
"data": self.data,
|
29 |
+
"successful": self.successful,
|
30 |
+
"error": self.error.dict() if self.error else None
|
31 |
+
}
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def exception_wrapper(http_error: int, error_message: str):
|
36 |
+
def decorator(func):
|
37 |
+
@wraps(func)
|
38 |
+
async def wrapper(*args, **kwargs):
|
39 |
+
try:
|
40 |
+
return await func(*args, **kwargs)
|
41 |
+
except Exception as e:
|
42 |
+
raise HTTPException(status_code=http_error, detail=error_message) from e
|
43 |
+
|
44 |
+
return wrapper
|
45 |
+
|
46 |
+
return decorator
|
47 |
+
|
48 |
+
|
49 |
+
def openai_wrapper(
|
50 |
+
temperature: int | float = 0, model: str = "gpt-4o-mini", is_json: bool = False, return_: str = None
|
51 |
+
):
|
52 |
+
def decorator(func):
|
53 |
+
@wraps(func)
|
54 |
+
async def wrapper(*args, **kwargs) -> str:
|
55 |
+
messages = await func(*args, **kwargs)
|
56 |
+
completion = await settings.OPENAI_CLIENT.chat.completions.create(
|
57 |
+
messages=messages,
|
58 |
+
temperature=temperature,
|
59 |
+
n=1,
|
60 |
+
model=model,
|
61 |
+
response_format={"type": "json_object"} if is_json else {"type": "text"}
|
62 |
+
)
|
63 |
+
response = completion.choices[0].message.content
|
64 |
+
if is_json:
|
65 |
+
response = json.loads(response)
|
66 |
+
if return_:
|
67 |
+
return pydash.get(response, return_)
|
68 |
+
return response
|
69 |
+
|
70 |
+
return wrapper
|
71 |
+
|
72 |
+
return decorator
|
73 |
+
|
74 |
+
|
75 |
+
def background_task():
|
76 |
+
def decorator(func):
|
77 |
+
@wraps(func)
|
78 |
+
async def wrapper(*args, **kwargs) -> str:
|
79 |
+
try:
|
80 |
+
result = await func(*args, **kwargs)
|
81 |
+
return result
|
82 |
+
except Exception as e:
|
83 |
+
pass
|
84 |
+
|
85 |
+
return wrapper
|
86 |
+
|
87 |
+
return decorator
|
requirements.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
annotated-types==0.7.0
|
2 |
+
anyio==4.8.0
|
3 |
+
certifi==2025.1.31
|
4 |
+
click==8.1.8
|
5 |
+
distro==1.9.0
|
6 |
+
dnspython==2.7.0
|
7 |
+
fastapi==0.115.8
|
8 |
+
h11==0.14.0
|
9 |
+
httpcore==1.0.7
|
10 |
+
httptools==0.6.4
|
11 |
+
httpx==0.28.1
|
12 |
+
idna==3.10
|
13 |
+
jiter==0.8.2
|
14 |
+
motor==3.7.0
|
15 |
+
openai==1.59.9
|
16 |
+
packaging==24.2
|
17 |
+
pdf2image==1.17.0
|
18 |
+
pillow==11.1.0
|
19 |
+
pydantic==2.10.6
|
20 |
+
pydantic_core==2.27.2
|
21 |
+
pymongo==4.11
|
22 |
+
pytesseract==0.3.13
|
23 |
+
python-dotenv==1.0.1
|
24 |
+
python-multipart==0.0.20
|
25 |
+
PyYAML==6.0.2
|
26 |
+
sniffio==1.3.1
|
27 |
+
starlette==0.45.3
|
28 |
+
tqdm==4.67.1
|
29 |
+
typing_extensions==4.12.2
|
30 |
+
uvicorn==0.34.0
|
31 |
+
uvloop==0.21.0
|
32 |
+
watchfiles==1.0.4
|
33 |
+
websockets==14.2
|