brestok commited on
Commit
3061962
·
1 Parent(s): b28ce62
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.index filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ env/
3
+ venv/
4
+ .venv/
5
+ .idea/
6
+ *.log
7
+ *.egg-info/
8
+ pip-wheel-metadata/
9
+ .env
10
+ .DS_Store
11
+ static/
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12.7

# Run the application as an unprivileged user (UID 1000).
RUN useradd -m -u 1000 user
USER user
# pip installs console scripts for this user into ~/.local/bin.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy and install requirements before the rest of the sources so the
# dependency layer is reused when only application code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# System packages need root. The apt package lists are removed in the same
# layer so they do not end up in the final image.
# NOTE(review): poppler-utils is presumably required by pdf2image - confirm.
USER root
RUN apt-get update \
    && apt-get install -y poppler-utils \
    && rm -rf /var/lib/apt/lists/*
USER user

COPY --chown=user . /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ocr 2
3
+ emoji: 😻
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ ---
main.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ocr import create_app
2
+
3
+ app = create_app()
ocr/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from starlette.exceptions import HTTPException as StarletteHTTPException
6
+ from starlette.staticfiles import StaticFiles
7
+
8
+ from ocr.core.config import settings
9
+ from ocr.core.wrappers import OcrResponseWrapper, ErrorOcrResponse
10
+
11
+
12
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Registers the report router, enables permissive CORS, mounts the
    ``static`` directory located under ``settings.BASE_DIR`` at ``/static``,
    installs a handler that renders HTTP errors in the
    ``OcrResponseWrapper`` envelope, and adds a root endpoint.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    # Imported inside the factory, as in the original code.
    # NOTE(review): presumably deferred to avoid a circular import - confirm.
    from ocr.api.message import report_router
    app.include_router(report_router, tags=['message'])

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    static_directory = os.path.join(settings.BASE_DIR, 'static')
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(static_directory, exist_ok=True)

    # Mount the directory that was just ensured to exist. A bare relative
    # 'static' would be resolved against the current working directory,
    # which is not necessarily BASE_DIR.
    app.mount(
        '/static',
        StaticFiles(directory=static_directory),
    )

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(_, exc):
        # Report HTTP errors in the same envelope as successful responses.
        return OcrResponseWrapper(
            data=None,
            successful=False,
            error=ErrorOcrResponse(message=str(exc.detail))
        ).response(exc.status_code)

    @app.get("/")
    async def read_root():
        return {"message": "Hello world!"}

    return app
ocr/api/message/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi.routing import APIRouter
2
+
3
+ report_router = APIRouter(
4
+ prefix="/api/report", tags=["message"]
5
+ )
6
+
7
+ from . import views
ocr/api/message/ai/openai_request.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ocr.api.message.ai.prompts import OCRPrompts
2
+ from ocr.core.wrappers import openai_wrapper
3
+
4
+
5
@openai_wrapper(model='gpt-4o-mini')
async def generate_report(request_content: list[dict]):
    """Build the chat messages used to generate a report.

    The function itself only returns the message list: a system message
    holding the report prompt followed by a user message holding
    ``request_content``. The ``openai_wrapper`` decorator is what consumes
    that list.

    Args:
        request_content: Content parts placed in the user message.
    """
    system_message = {
        "role": "system",
        "content": OCRPrompts.generate_general_answer,
    }
    user_message = {
        "role": "user",
        "content": request_content,
    }
    return [system_message, user_message]
ocr/api/message/ai/prompts.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class OCRPrompts:
    """Prompt templates sent to the language model."""

    # System prompt describing the required report layout. The model is asked
    # to follow the template strictly, so the Markdown inside it must itself be
    # well formed (bold markers without inner spaces, blank lines between
    # sub-sections).
    generate_general_answer = """## Task

You must analyze the text extracted from medical document and generate a comprehensive report in **Markdown2** format. Ensure that every detail provided in the document is included, and do not omit or modify any information. Your output must strictly follow the required format.

## Report Structure

The report should be structured as follows, with each section containing only relevant information from the document:

```markdown
## Patient Information

- Name: [Patient Name]
- Age: [Patient Age]
- Date of Scan: [Date]
- Indication: [Reason for the CT scan]

## Findings

**Primary findings**:
[Describe significant abnormalities or findings relevant to the indication]

**Secondary findings**:
[List incidental findings, e.g., "Mild hepatic steatosis noted."]

**No abnormalities**:
[Mention organs or systems without abnormalities, e.g., "No evidence of lymphadenopathy or pleural effusion."]

## Impression

[Summarize the findings concisely, e.g., "Findings suggest a primary lung tumor. Biopsy recommended for further evaluation."]

## Recommendations

[Include next steps or further tests, e.g., "PET scan and consultation with oncology recommended."]
```

[INST]

## Instructions

- **Do not invent or infer any information.** Only use data provided in the user request.
- Ensure that the format is followed strictly, and the output is complete without any deviations.

[/INST]"""
ocr/api/message/db_requests.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from ocr.api.message.model import MessageModel
4
+ from ocr.core.config import settings
5
+
6
+
7
async def get_all_chat_messages_obj(page_size: int, page_index: int) -> tuple[list[MessageModel], int]:
    """Fetch one page of the ``messages`` collection and its total size.

    The page query and the count query are awaited concurrently.

    Args:
        page_size: Maximum number of documents to return.
        page_index: Zero-based page number; ``page_size * page_index``
            documents are skipped.

    Returns:
        A ``(documents, total_count)`` tuple.
    """
    # NOTE(review): the documents are returned exactly as the driver yields
    # them; nothing here converts them to MessageModel - confirm callers cope.
    collection = settings.DB_CLIENT.messages
    offset = page_size * page_index
    page_query = collection.find().skip(offset).limit(page_size).to_list(length=page_size)
    count_query = collection.count_documents({})
    documents, total = await asyncio.gather(page_query, count_query)
    return documents, total
18
+
19
+
20
async def save_report_obj(report: str, filename: str) -> MessageModel:
    """Persist a generated report in the ``messages`` collection.

    Args:
        report: Report text to store.
        filename: Name of the file the report was generated from.

    Returns:
        The ``MessageModel`` instance that was inserted.
    """
    record = MessageModel(text=report, filename=filename)
    document = record.to_mongo()
    await settings.DB_CLIENT.messages.insert_one(document)
    return record
ocr/api/message/dto.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
class Paging(BaseModel):
    # Pagination metadata returned alongside a list of items.
    # (Comments rather than a docstring, so the generated schema is unchanged.)

    # Page size reported for the page.
    pageSize: int
    # Index of the page.
    pageIndex: int
    # Item count reported for the listing.
    totalCount: int
ocr/api/message/model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from pydantic import Field
4
+
5
+ from ocr.core.database import MongoBaseModel
6
+
7
+
8
class MessageModel(MongoBaseModel):
    # A generated report together with the name of its source file.
    # (Comments rather than a docstring, so the generated schema is unchanged.)

    # Report text.
    text: str
    # Name of the uploaded file the report belongs to.
    filename: str
    # Both timestamps default to the local time at which the instance is built.
    datetimeInserted: datetime = Field(default_factory=datetime.now)
    datetimeUpdated: datetime = Field(default_factory=datetime.now)
ocr/api/message/schemas.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ from ocr.api.message.dto import Paging
4
+ from ocr.api.message.model import MessageModel
5
+ from ocr.core.wrappers import OcrResponseWrapper
6
+
7
+
8
class CreateMessageRequest(BaseModel):
    # Request body carrying a single text field.
    # (Comments rather than a docstring, so the generated schema is unchanged.)
    text: str
10
+
11
+
12
class MessageWrapper(OcrResponseWrapper[MessageModel]):
    # Response envelope whose ``data`` is a single MessageModel.
    pass
14
+
15
+
16
class AllMessageResponse(BaseModel):
    # Payload of the message listing: paging metadata plus the messages.
    # (Comments rather than a docstring, so the generated schema is unchanged.)

    # Pagination metadata for ``data``.
    paging: Paging
    # Messages contained in the page.
    data: list[MessageModel]
19
+
20
+
21
class AllMessageWrapper(OcrResponseWrapper[AllMessageResponse]):
    # Response envelope whose ``data`` is an AllMessageResponse.
    pass
ocr/api/message/views.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from fastapi import Query, UploadFile, File
4
+
5
+ from ocr.api.message import report_router
6
+ from ocr.api.message.ai.openai_request import generate_report
7
+ from ocr.api.message.db_requests import get_all_chat_messages_obj, save_report_obj
8
+ from ocr.api.message.dto import Paging
9
+ from ocr.api.message.model import MessageModel
10
+ from ocr.api.message.schemas import (AllMessageWrapper,
11
+ AllMessageResponse)
12
+ from ocr.api.message.utils import divide_images, prepare_request_content, clean_response
13
+ from ocr.core.wrappers import OcrResponseWrapper
14
+
15
+
16
@report_router.get('/all')
async def get_all_chat_messages(
        pageSize: Optional[int] = Query(10, ge=1, description="Number of messages to return per page"),
        pageIndex: Optional[int] = Query(0, ge=0, description="Page index to retrieve"),
) -> AllMessageWrapper:
    """Return one page of stored messages with paging metadata.

    Args:
        pageSize: Maximum number of messages per page (at least 1).
        pageIndex: Zero-based index of the page to return.

    Returns:
        An ``AllMessageWrapper`` whose payload holds the page of messages and
        a ``Paging`` block echoing the requested page size and index together
        with the total number of stored messages.
    """
    messages, total_count = await get_all_chat_messages_obj(pageSize, pageIndex)
    response = AllMessageResponse(
        # Report the real position within the full collection so that clients
        # can compute how many pages exist.
        paging=Paging(pageSize=pageSize, pageIndex=pageIndex, totalCount=total_count),
        data=messages,
    )
    return AllMessageWrapper(data=response)
27
+
28
+
29
@report_router.post('/generate')
async def create_message(
        file: UploadFile = File(...),
) -> OcrResponseWrapper[MessageModel]:
    """Generate a report from an uploaded file and store it.

    The upload is read fully, passed through ``divide_images`` and
    ``prepare_request_content``, sent to ``generate_report``, and the cleaned
    result is saved under the upload's filename. The upload is closed whether
    or not any of these steps fails.

    Args:
        file: The uploaded document.

    Returns:
        An ``OcrResponseWrapper`` containing the saved message.
    """
    try:
        raw_bytes = await file.read()
        # NOTE(review): divide_images / prepare_request_content / clean_response
        # live in ocr.api.message.utils, which is not visible here.
        request_content = prepare_request_content(divide_images(raw_bytes))
        raw_report = await generate_report(request_content)
        saved = await save_report_obj(clean_response(raw_report), file.filename)
        return OcrResponseWrapper(data=saved)
    finally:
        await file.close()
ocr/core/config.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ from functools import lru_cache
4
+
5
+ import motor.motor_asyncio
6
+ from dotenv import load_dotenv
7
+ from openai import AsyncClient
8
+
9
+ load_dotenv()
10
+
11
class BaseConfig:
    """Settings shared by every environment.

    All attributes are evaluated once, when the class body is executed at
    import time; this includes construction of the OpenAI and MongoDB clients.
    """

    # Repository root: three directory levels above this file.
    BASE_DIR: pathlib.Path = pathlib.Path(__file__).parent.parent.parent
    # Read from the SECRET environment variable (None when unset).
    SECRET_KEY = os.environ.get('SECRET')
    # Async OpenAI client configured from OPENAI_API_KEY.
    OPENAI_CLIENT = AsyncClient(api_key=os.environ.get('OPENAI_API_KEY'))
    # Handle to the "Ocr" database of the server given by MONGO_DB_URL.
    DB_CLIENT = motor.motor_asyncio.AsyncIOMotorClient(os.environ.get("MONGO_DB_URL")).Ocr
16
+
17
class DevelopmentConfig(BaseConfig):
    """Settings used when ``FASTAPI_CONFIG`` is ``development`` (the default)."""

    # NOTE(review): Issuer/Audience look like token issuer and audience URLs;
    # nothing in the visible code reads them - confirm their use.
    Issuer = "http://localhost:8000"
    Audience = "http://localhost:3000"
20
+
21
+
22
class ProductionConfig(BaseConfig):
    """Settings used when ``FASTAPI_CONFIG`` is ``production``."""

    # Both values are empty strings in this configuration.
    Issuer = ""
    Audience = ""
25
+
26
+
27
@lru_cache()
def get_settings() -> DevelopmentConfig | ProductionConfig:
    """Return the configuration object selected by ``FASTAPI_CONFIG``.

    The environment variable defaults to ``development``. The result is
    cached, so the same instance is returned on every call.

    Raises:
        KeyError: If ``FASTAPI_CONFIG`` names an unknown configuration.
    """
    selected = os.getenv('FASTAPI_CONFIG', default='development')
    if selected == 'development':
        return DevelopmentConfig()
    if selected == 'production':
        return ProductionConfig()
    raise KeyError(selected)
36
+
37
+
38
+ settings = get_settings()
ocr/core/database.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from enum import Enum
3
+ from typing import Dict, Any, Type
4
+
5
+ from bson import ObjectId
6
+ from pydantic import GetCoreSchemaHandler, BaseModel, Field, AnyUrl
7
+ from pydantic.json_schema import JsonSchemaValue
8
+ from pydantic_core import core_schema
9
+
10
+
11
class PyObjectId:
    """Wrapper around ``bson.ObjectId`` with pydantic schema hooks.

    An instance holds an ``ObjectId`` in ``value`` and forwards unknown
    attribute lookups to it. Used as a pydantic field type, input is
    validated as a string and converted to an ``ObjectId``.
    """

    def __init__(self, value: str = None):
        """Wrap ``value``, or a freshly generated ``ObjectId`` when omitted.

        Raises:
            ValueError: If ``value`` is given but is not a valid ObjectId.
        """
        if value is None:
            self.value = ObjectId()
        else:
            self.value = self.validate(value)

    @classmethod
    def __get_pydantic_core_schema__(
            cls, source: type, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # ``validate`` accepts only the value, so it has to be registered as a
        # no-info validator; a with-info validator is additionally passed a
        # ValidationInfo argument, which ``validate`` does not accept.
        return core_schema.no_info_after_validator_function(
            cls.validate, core_schema.str_schema()
        )

    @classmethod
    def __get_pydantic_json_schema__(
            cls, schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
    ) -> JsonSchemaValue:
        # Object ids are exposed as plain strings in the JSON schema.
        return {"type": "string"}

    @classmethod
    def validate(cls, value: str) -> ObjectId:
        """Convert ``value`` to an ``ObjectId``.

        Raises:
            ValueError: If ``value`` is not a valid ObjectId.
        """
        if not ObjectId.is_valid(value):
            raise ValueError(f"Invalid ObjectId: {value}")
        return ObjectId(value)

    def __getattr__(self, item):
        # Only reached when normal lookup fails. If ``value`` has not been set
        # (for example on an instance created without __init__), report a
        # regular AttributeError instead of leaking a KeyError.
        try:
            wrapped = self.__dict__['value']
        except KeyError:
            raise AttributeError(item) from None
        return getattr(wrapped, item)

    def __str__(self):
        return str(self.value)
43
+
44
+
45
class MongoBaseModel(BaseModel):
    """Base class for models stored in MongoDB.

    Provides a string ``id`` generated from a new ObjectId, plus helpers to
    convert a model to a plain document (``to_mongo``) and to rebuild a model
    from a stored document (``from_mongo``).
    """

    id: str = Field(default_factory=lambda: str(PyObjectId()))

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = {"arbitrary_types_allowed": True}

    def to_mongo(self) -> Dict[str, Any]:
        """Return the model as a plain dict.

        Nested models (also inside lists) are converted recursively, enum
        members are replaced by their values, datetimes by ISO-8601 strings
        and URLs by strings. A field's alias is used as the key when set.
        """

        def model_to_dict(model: BaseModel) -> Dict[str, Any]:
            fields = type(model).model_fields
            doc: Dict[str, Any] = {}
            # Iterating a model yields (field_name, value) pairs.
            for name, value in model:
                field = fields.get(name)
                key = field.alias if field is not None and field.alias else name

                if isinstance(value, BaseModel):
                    doc[key] = model_to_dict(value)
                elif isinstance(value, list) and all(isinstance(i, BaseModel) for i in value):
                    doc[key] = [model_to_dict(item) for item in value]
                elif isinstance(value, Enum):
                    # No truthiness guard: members with falsy values (0, "")
                    # must be converted as well.
                    doc[key] = value.value
                elif isinstance(value, datetime):
                    doc[key] = value.isoformat()
                elif isinstance(value, AnyUrl):
                    doc[key] = str(value)
                else:
                    doc[key] = value

            return doc

        return model_to_dict(self)

    @classmethod
    def from_mongo(cls, data: Dict[str, Any]):
        """Build an instance from a stored document.

        Returns ``None`` when ``data`` is ``None``. After construction, fields
        annotated directly with an ``Enum`` class are re-coerced to that enum,
        recursing into nested models found as field values, list items or
        dict values.
        """

        def restore_enums(inst: Any, model_cls: Type[BaseModel]) -> None:
            for name, field in model_cls.model_fields.items():
                value = getattr(inst, name)
                annotation = field.annotation
                if isinstance(annotation, type) and issubclass(annotation, Enum):
                    setattr(inst, name, annotation(value))
                elif isinstance(value, BaseModel):
                    restore_enums(value, value.__class__)
                elif isinstance(value, list):
                    # The annotation is known not to be an Enum class here, so
                    # only nested models need processing.
                    for item in value:
                        if isinstance(item, BaseModel):
                            restore_enums(item, item.__class__)
                elif isinstance(value, dict):
                    for item in value.values():
                        if isinstance(item, BaseModel):
                            restore_enums(item, item.__class__)

        if data is None:
            return None
        instance = cls(**data)
        restore_enums(instance, instance.__class__)
        return instance
ocr/core/wrappers.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import wraps
3
+ from typing import Generic, Optional, TypeVar
4
+
5
+ import pydash
6
+ from fastapi import HTTPException
7
+ from pydantic import BaseModel
8
+ from starlette.responses import JSONResponse
9
+
10
+ from ocr.core.config import settings
11
+
12
+ T = TypeVar('T')
13
+
14
+
15
class ErrorOcrResponse(BaseModel):
    # Error details carried in the ``error`` field of a response envelope.
    # (Comments rather than a docstring, so the generated schema is unchanged.)

    # Human-readable error description.
    message: str
17
+
18
+
19
class OcrResponseWrapper(BaseModel, Generic[T]):
    # Generic response envelope: a payload, a success flag and optional error
    # details. (Comments rather than a docstring, so the generated schema is
    # unchanged.)

    # Payload; None when there is nothing to return.
    data: Optional[T] = None
    # False for error responses.
    successful: bool = True
    # Error details; None on success.
    error: Optional[ErrorOcrResponse] = None

    def response(self, status_code: int):
        """Render the envelope as a ``JSONResponse`` with ``status_code``.

        The body has the keys ``data``, ``successful`` and ``error``. It is
        produced with ``model_dump(mode="json")`` so that payloads containing
        models or other non-primitive values are converted to JSON-compatible
        types before being handed to ``JSONResponse``.
        """
        return JSONResponse(
            status_code=status_code,
            content=self.model_dump(mode="json"),
        )
33
+
34
+
35
def exception_wrapper(http_error: int, error_message: str):
    """Create a decorator that converts failures of a coroutine into HTTP errors.

    Args:
        http_error: Status code of the ``HTTPException`` raised on failure.
        error_message: Detail message of that exception.

    Returns:
        A decorator for coroutine functions. The decorated coroutine returns
        the wrapped result unchanged. An ``HTTPException`` raised by the
        wrapped coroutine propagates as is; any other ``Exception`` is
        replaced by ``HTTPException(http_error, error_message)`` chained to
        the original error.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except HTTPException:
                # Already an HTTP error with its own status and detail;
                # re-wrapping would discard that information.
                raise
            except Exception as e:
                raise HTTPException(status_code=http_error, detail=error_message) from e

        return wrapper

    return decorator
47
+
48
+
49
+ def openai_wrapper(
50
+ temperature: int | float = 0, model: str = "gpt-4o-mini", is_json: bool = False, return_: str = None
51
+ ):
52
+ def decorator(func):
53
+ @wraps(func)
54
+ async def wrapper(*args, **kwargs) -> str:
55
+ messages = await func(*args, **kwargs)
56
+ completion = await settings.OPENAI_CLIENT.chat.completions.create(
57
+ messages=messages,
58
+ temperature=temperature,
59
+ n=1,
60
+ model=model,
61
+ response_format={"type": "json_object"} if is_json else {"type": "text"}
62
+ )
63
+ response = completion.choices[0].message.content
64
+ if is_json:
65
+ response = json.loads(response)
66
+ if return_:
67
+ return pydash.get(response, return_)
68
+ return response
69
+
70
+ return wrapper
71
+
72
+ return decorator
73
+
74
+
75
def background_task():
    """Create a decorator for best-effort coroutines.

    The decorated coroutine returns the wrapped result on success. If the
    wrapped coroutine raises an ``Exception``, the error is logged with its
    traceback and ``None`` is returned, so a failing background job never
    propagates to its caller.

    Returns:
        A decorator for coroutine functions.
    """
    # Local import: this block is the only user of logging.
    import logging

    logger = logging.getLogger(__name__)

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception:
                # Still best-effort, but leave a trace instead of silently
                # discarding the failure.
                logger.exception("Background task %s failed", getattr(func, "__name__", func))
                return None

        return wrapper

    return decorator
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.7.0
2
+ anyio==4.8.0
3
+ certifi==2025.1.31
4
+ click==8.1.8
5
+ distro==1.9.0
6
+ dnspython==2.7.0
7
+ fastapi==0.115.8
8
+ h11==0.14.0
9
+ httpcore==1.0.7
10
+ httptools==0.6.4
11
+ httpx==0.28.1
12
+ idna==3.10
13
+ jiter==0.8.2
14
+ motor==3.7.0
15
+ openai==1.59.9
16
+ packaging==24.2
17
+ pdf2image==1.17.0
18
+ pillow==11.1.0
19
+ pydantic==2.10.6
20
+ pydantic_core==2.27.2
21
+ pymongo==4.11
22
+ pytesseract==0.3.13
23
+ python-dotenv==1.0.1
24
+ python-multipart==0.0.20
25
+ PyYAML==6.0.2
26
+ sniffio==1.3.1
27
+ starlette==0.45.3
28
+ tqdm==4.67.1
29
+ typing_extensions==4.12.2
30
+ uvicorn==0.34.0
31
+ uvloop==0.21.0
32
+ watchfiles==1.0.4
33
+ websockets==14.2