add features
Browse files- ocr/__init__.py +6 -3
- ocr/api/message/__init__.py +2 -2
- ocr/api/message/ai/openai_request.py +0 -17
- ocr/api/message/ai/prompts.py +0 -44
- ocr/api/message/db_requests.py +31 -18
- ocr/api/message/dto.py +5 -4
- ocr/api/message/models.py +14 -0
- ocr/api/message/schemas.py +2 -3
- ocr/api/message/views.py +18 -27
- ocr/api/openai_requests.py +0 -0
- ocr/api/prompts.py +0 -0
- ocr/api/report/__init__.py +7 -0
- ocr/api/report/db_requests.py +27 -0
- ocr/api/report/dto.py +16 -0
- ocr/api/{message → report}/model.py +4 -3
- ocr/api/report/schemas.py +8 -0
- ocr/api/report/views.py +41 -0
- ocr/api/{message/utils.py → utils.py} +14 -0
ocr/__init__.py
CHANGED
@@ -12,8 +12,11 @@ from ocr.core.wrappers import OcrResponseWrapper, ErrorOcrResponse
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
15 |
-
from ocr.api.
|
16 |
-
app.include_router(report_router, tags=['
|
|
|
|
|
|
|
17 |
|
18 |
app.add_middleware(
|
19 |
CORSMiddleware,
|
@@ -41,6 +44,6 @@ def create_app() -> FastAPI:
|
|
41 |
|
42 |
@app.get("/")
|
43 |
async def read_root():
|
44 |
-
return {"
|
45 |
|
46 |
return app
|
|
|
12 |
def create_app() -> FastAPI:
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
from ocr.api.report import report_router
|
16 |
+
app.include_router(report_router, tags=['report'])
|
17 |
+
|
18 |
+
from ocr.api.message import message_router
|
19 |
+
app.include_router(message_router, tags=['message'])
|
20 |
|
21 |
app.add_middleware(
|
22 |
CORSMiddleware,
|
|
|
44 |
|
45 |
@app.get("/")
|
46 |
async def read_root():
|
47 |
+
return {"report": "Hello world!"}
|
48 |
|
49 |
return app
|
ocr/api/message/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from fastapi.routing import APIRouter
|
2 |
|
3 |
-
|
4 |
-
prefix="/api/
|
5 |
)
|
6 |
|
7 |
from . import views
|
|
|
1 |
from fastapi.routing import APIRouter
|
2 |
|
3 |
+
# Router that groups every chat-message endpoint under /api/message.
message_router = APIRouter(
    prefix="/api/message",
    tags=["message"],
)

# Imported last on purpose: views.py imports ``message_router`` from this
# package, so the router has to exist before the views module is loaded.
from . import views
|
ocr/api/message/ai/openai_request.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
from ocr.api.message.ai.prompts import OCRPrompts
|
2 |
-
from ocr.core.wrappers import openai_wrapper
|
3 |
-
|
4 |
-
|
5 |
-
@openai_wrapper(model='gpt-4o-mini')
|
6 |
-
async def generate_report(request_content: list[dict]):
|
7 |
-
messages = [
|
8 |
-
{
|
9 |
-
"role": "system",
|
10 |
-
"content": OCRPrompts.generate_general_answer
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"role": "user",
|
14 |
-
"content": request_content
|
15 |
-
}
|
16 |
-
]
|
17 |
-
return messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ocr/api/message/ai/prompts.py
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
class OCRPrompts:
|
2 |
-
generate_general_answer = """## Task
|
3 |
-
|
4 |
-
You must analyze the text extracted from medical document and generate a comprehensive report in **Markdown2** format. Ensure that every detail provided in the document is included, and do not omit or modify any information. Your output must strictly follow the required format.
|
5 |
-
|
6 |
-
## Report Structure
|
7 |
-
|
8 |
-
The report should be structured as follows, with each section containing only relevant information from the document:
|
9 |
-
|
10 |
-
```markdown
|
11 |
-
## Patient Information
|
12 |
-
|
13 |
-
- Name: [Patient Name]
|
14 |
-
- Age: [Patient Age]
|
15 |
-
- Date of Scan: [Date]
|
16 |
-
- Indication: [Reason for the CT scan]
|
17 |
-
|
18 |
-
## Findings
|
19 |
-
|
20 |
-
**Primary findings**:
|
21 |
-
[Describe significant abnormalities or findings relevant to the indication]
|
22 |
-
|
23 |
-
** Secondary findings**:
|
24 |
-
[List incidental findings, e.g., "Mild hepatic steatosis noted."]
|
25 |
-
**No abnormalities**:
|
26 |
-
[Mention organs or systems without abnormalities, e.g., "No evidence of lymphadenopathy or pleural effusion."]
|
27 |
-
|
28 |
-
## Impression
|
29 |
-
|
30 |
-
[Summarize the findings concisely, e.g., "Findings suggest a primary lung tumor. Biopsy recommended for further evaluation."]
|
31 |
-
|
32 |
-
## Recommendations
|
33 |
-
|
34 |
-
[Include next steps or further tests, e.g., "PET scan and consultation with oncology recommended."]
|
35 |
-
```
|
36 |
-
|
37 |
-
[INST]
|
38 |
-
|
39 |
-
## Instructions
|
40 |
-
|
41 |
-
- **Do not invent or infer any information.** Only use data provided in the user request.
|
42 |
-
- Ensure that the format is followed strictly, and the output is complete without any deviations.
|
43 |
-
|
44 |
-
[/INST]"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ocr/api/message/db_requests.py
CHANGED
@@ -1,26 +1,39 @@
|
|
1 |
import asyncio
|
2 |
|
3 |
-
from
|
4 |
-
from ocr.core.config import settings
|
5 |
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
async def get_all_chat_messages_obj(page_size: int, page_index: int) -> tuple[list[MessageModel], int]:
|
8 |
-
skip = page_size * page_index
|
9 |
-
objects, total_count = await asyncio.gather(
|
10 |
-
settings.DB_CLIENT.messages
|
11 |
-
.find()
|
12 |
-
.skip(skip)
|
13 |
-
.limit(page_size)
|
14 |
-
.to_list(length=page_size),
|
15 |
-
settings.DB_CLIENT.messages.count_documents({})
|
16 |
-
)
|
17 |
-
return objects, total_count
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
message = MessageModel(
|
22 |
-
text=report,
|
23 |
-
filename=filename,
|
24 |
-
)
|
25 |
await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
|
26 |
return message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
|
3 |
+
from fastapi import HTTPException
|
|
|
4 |
|
5 |
+
from ocr.api.message.dto import Author
|
6 |
+
from ocr.api.message.models import MessageModel
|
7 |
+
from ocr.api.message.schemas import CreateMessageRequest
|
8 |
+
from ocr.core.config import settings
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
async def create_message_obj(
    report_id: str,
    message_data: CreateMessageRequest,
) -> MessageModel:
    """Store a user-authored message for an existing report.

    Args:
        report_id: Value of the ``id`` field of the parent report.
        message_data: Request payload; its dumped fields seed the message.

    Returns:
        The message model that was inserted into the messages collection.

    Raises:
        HTTPException: 404 when no report has ``id == report_id``.
    """
    db = settings.DB_CLIENT

    parent_report = await db.reports.find_one({"id": report_id})
    if not parent_report:
        raise HTTPException(status_code=404, detail="Report not found.")

    payload = message_data.model_dump()
    message = MessageModel(**payload, reportId=report_id, author=Author.User)
    await db.messages.insert_one(message.to_mongo())
    return message
|
21 |
+
|
22 |
+
|
23 |
+
async def get_all_chat_messages_obj(report_id: str) -> list[MessageModel]:
    """Return every message attached to the given report.

    The message query and the report lookup are awaited together through
    ``asyncio.gather``.

    Args:
        report_id: Value of the ``id`` field of the report.

    Returns:
        All messages whose ``reportId`` equals ``report_id``.

    Raises:
        HTTPException: 404 when no report has ``id == report_id``.
    """
    raw_messages, report = await asyncio.gather(
        settings.DB_CLIENT.messages.find({"reportId": report_id}).to_list(length=None),
        settings.DB_CLIENT.reports.find_one({"id": report_id}),
    )

    # Check the report first: deserialising messages for an unknown report is
    # wasted work, and a malformed message document must not mask the 404.
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")

    return [MessageModel.from_mongo(document) for document in raw_messages]
|
32 |
+
|
33 |
+
|
34 |
+
async def save_assistant_user_message(user_message: str, assistant_message: str, report_id: str) -> MessageModel:
    """Persist a user message and the assistant's reply for one report.

    The user message is inserted first, then the assistant message.

    Args:
        user_message: Text written by the user.
        assistant_message: Text produced for the assistant.
        report_id: ``id`` of the report both messages belong to.

    Returns:
        The stored user message; the assistant message is not returned.
    """
    user_entry = MessageModel(reportId=report_id, author=Author.User, text=user_message)
    assistant_entry = MessageModel(reportId=report_id, author=Author.Assistant, text=assistant_message)

    for entry in (user_entry, assistant_entry):
        await settings.DB_CLIENT.messages.insert_one(entry.to_mongo())

    # NOTE(review): the create_message view returns this value as the endpoint
    # payload - confirm that the user message, not the assistant reply, is the
    # intended result.
    return user_entry
|
ocr/api/message/dto.py
CHANGED
@@ -1,7 +1,8 @@
|
|
|
|
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
|
4 |
-
class
|
5 |
-
|
6 |
-
|
7 |
-
totalCount: int
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
from pydantic import BaseModel
|
4 |
|
5 |
|
6 |
+
class Author(Enum):
    """Who wrote a chat message."""

    # The string values are what transform_messages_to_openai emits as the
    # "role" of each message.
    User = "user"
    Assistant = "assistant"
|
|
ocr/api/message/models.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from pydantic import Field
|
4 |
+
|
5 |
+
from ocr.api.message.dto import Author
|
6 |
+
from ocr.core.database import MongoBaseModel
|
7 |
+
|
8 |
+
|
9 |
+
class MessageModel(MongoBaseModel):
    """A single chat message attached to a report."""

    # ``id`` of the report this message belongs to.
    reportId: str
    # Whether the user or the assistant wrote the message.
    author: Author
    # Message body.
    text: str
    # Both timestamps default to ``datetime.now()`` evaluated when the model is
    # built, which yields a naive (timezone-less) datetime.
    datetimeInserted: datetime = Field(default_factory=datetime.now)
    datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
ocr/api/message/schemas.py
CHANGED
@@ -1,14 +1,13 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
-
from ocr.api.message.
|
4 |
-
from ocr.api.
|
5 |
from ocr.core.wrappers import OcrResponseWrapper
|
6 |
|
7 |
|
8 |
class CreateMessageRequest(BaseModel):
|
9 |
text: str
|
10 |
|
11 |
-
|
12 |
class MessageWrapper(OcrResponseWrapper[MessageModel]):
|
13 |
pass
|
14 |
|
|
|
1 |
from pydantic import BaseModel
|
2 |
|
3 |
+
from ocr.api.message.models import MessageModel
|
4 |
+
from ocr.api.report.dto import Paging
|
5 |
from ocr.core.wrappers import OcrResponseWrapper
|
6 |
|
7 |
|
8 |
class CreateMessageRequest(BaseModel):
    """Request body for posting a new chat message."""

    # Text of the message; the create_message view stores it as the user message.
    text: str
|
10 |
|
|
|
11 |
class MessageWrapper(OcrResponseWrapper[MessageModel]):
    """``OcrResponseWrapper`` specialised for a single ``MessageModel``."""
    pass
|
13 |
|
ocr/api/message/views.py
CHANGED
@@ -1,24 +1,17 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
from
|
4 |
-
|
5 |
-
from ocr.api.
|
6 |
-
from ocr.api.
|
7 |
-
from ocr.api.message.db_requests import get_all_chat_messages_obj, save_report_obj
|
8 |
-
from ocr.api.message.dto import Paging
|
9 |
-
from ocr.api.message.model import MessageModel
|
10 |
-
from ocr.api.message.schemas import (AllMessageWrapper,
|
11 |
-
AllMessageResponse)
|
12 |
-
from ocr.api.message.utils import divide_images, prepare_request_content, clean_response
|
13 |
from ocr.core.wrappers import OcrResponseWrapper
|
14 |
|
15 |
|
16 |
-
@
|
17 |
async def get_all_chat_messages(
|
18 |
-
|
19 |
-
pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
|
20 |
) -> AllMessageWrapper:
|
21 |
-
messages
|
22 |
response = AllMessageResponse(
|
23 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
24 |
data=messages
|
@@ -26,16 +19,14 @@ async def get_all_chat_messages(
|
|
26 |
return AllMessageWrapper(data=response)
|
27 |
|
28 |
|
29 |
-
@
|
30 |
async def create_message(
|
31 |
-
|
|
|
32 |
) -> OcrResponseWrapper[MessageModel]:
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
return OcrResponseWrapper(data=response)
|
40 |
-
finally:
|
41 |
-
await file.close()
|
|
|
1 |
+
from ocr.api.message import message_router
|
2 |
+
from ocr.api.message.db_requests import get_all_chat_messages_obj, save_assistant_user_message
|
3 |
+
from ocr.api.message.models import MessageModel
|
4 |
+
from ocr.api.message.schemas import AllMessageWrapper, AllMessageResponse, CreateMessageRequest
|
5 |
+
from ocr.api.report.dto import Paging
|
6 |
+
from ocr.api.utils import transform_messages_to_openai
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from ocr.core.wrappers import OcrResponseWrapper
|
8 |
|
9 |
|
10 |
+
@message_router.get('/{reportId}/all')
|
11 |
async def get_all_chat_messages(
|
12 |
+
reportId: str
|
|
|
13 |
) -> AllMessageWrapper:
|
14 |
+
messages = await get_all_chat_messages_obj(reportId)
|
15 |
response = AllMessageResponse(
|
16 |
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
|
17 |
data=messages
|
|
|
19 |
return AllMessageWrapper(data=response)
|
20 |
|
21 |
|
22 |
+
@message_router.post('/{reportId}')
async def create_message(
    reportId: str,
    message_data: CreateMessageRequest,
) -> OcrResponseWrapper[MessageModel]:
    """Record a user message and a placeholder assistant reply for a report.

    The existing conversation is loaded first; that lookup raises a 404
    ``HTTPException`` when the report does not exist.

    Args:
        reportId: ``id`` of the report the conversation belongs to.
        message_data: Request body carrying the user's text.

    Returns:
        A wrapper around the value returned by ``save_assistant_user_message``.
    """
    history = await get_all_chat_messages_obj(reportId)
    # Built for the upcoming model call; nothing consumes it yet.
    message_history = transform_messages_to_openai(history)

    # TODO: replace the placeholder with the real call (generate_response).
    assistant_reply = 'Hello world'

    saved_message = await save_assistant_user_message(message_data.text, assistant_reply, reportId)
    return OcrResponseWrapper(data=saved_message)
|
|
|
|
|
|
ocr/api/openai_requests.py
ADDED
File without changes
|
ocr/api/prompts.py
ADDED
File without changes
|
ocr/api/report/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.routing import APIRouter
|
2 |
+
|
3 |
+
# Router that groups every report endpoint under /api/report.
report_router = APIRouter(
    prefix="/api/report",
    tags=["report"],
)

# Imported last on purpose: views.py imports ``report_router`` from this
# package, so the router has to exist before the views module is loaded.
from . import views
|
ocr/api/report/db_requests.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException
|
2 |
+
|
3 |
+
from ocr.api.report.dto import ReportModelShort
|
4 |
+
from ocr.api.report.model import ReportModel
|
5 |
+
from ocr.core.config import settings
|
6 |
+
|
7 |
+
|
8 |
+
async def get_all_reports_obj() -> list[ReportModelShort]:
    """Load every report document and return each one in its short form."""
    cursor = settings.DB_CLIENT.reports.find({})
    documents = await cursor.to_list(length=None)
    return [ReportModelShort(**document) for document in documents]
|
11 |
+
|
12 |
+
|
13 |
+
async def delete_all_reports() -> None:
    """Remove every document from the reports collection.

    Only the reports collection is touched here; stored messages are left
    in place.
    """
    reports_collection = settings.DB_CLIENT.reports
    await reports_collection.delete_many({})
|
15 |
+
|
16 |
+
|
17 |
+
async def get_report_obj_by_id(report_id: str) -> ReportModel:
    """Fetch a single report by its ``id`` field.

    Args:
        report_id: Value of the ``id`` field to look up.

    Returns:
        The matching report converted with ``ReportModel.from_mongo``.

    Raises:
        HTTPException: 404 when no matching document is found.
    """
    document = await settings.DB_CLIENT.reports.find_one({"id": report_id})
    if document:
        return ReportModel.from_mongo(document)
    raise HTTPException(status_code=404, detail="Report not found")
|
22 |
+
|
23 |
+
|
24 |
+
async def save_report_obj(report: str, changes: str, filename: str = 'maksim.docx') -> ReportModel:
    """Create a report and insert it into the reports collection.

    Args:
        report: Text stored in the report's ``report`` field.
        changes: Text stored in the report's ``changes`` field.
        filename: Name recorded for the source file. Defaults to the value
            that used to be hard-coded, so existing two-argument callers
            behave exactly as before.

    Returns:
        The report model that was inserted.
    """
    # A separate name keeps the ``report`` argument a plain string throughout.
    report_obj = ReportModel(report=report, changes=changes, filename=filename)
    await settings.DB_CLIENT.reports.insert_one(report_obj.to_mongo())
    return report_obj
|
ocr/api/report/dto.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import ClassVar
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
from ocr.api.report.model import ReportModel
|
6 |
+
|
7 |
+
|
8 |
+
class Paging(BaseModel):
    """Paging metadata returned alongside list responses."""

    # Number of items on the returned page.
    pageSize: int
    # Index of the returned page; the views in this project always pass 0.
    pageIndex: int
    # Total number of items available.
    totalCount: int
|
12 |
+
|
13 |
+
|
14 |
+
class ReportModelShort(ReportModel):
    """Variant of ``ReportModel`` used for the report list."""

    # NOTE(review): re-declaring the inherited fields as ClassVar is presumably
    # meant to drop ``report`` and ``changes`` from the model's fields so list
    # responses stay small - confirm this against the pydantic version in use.
    report: ClassVar[str]
    changes: ClassVar[str]
|
ocr/api/{message → report}/model.py
RENAMED
@@ -5,8 +5,9 @@ from pydantic import Field
|
|
5 |
from ocr.core.database import MongoBaseModel
|
6 |
|
7 |
|
8 |
-
class
|
9 |
-
|
|
|
10 |
filename: str
|
11 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
12 |
-
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
5 |
from ocr.core.database import MongoBaseModel
|
6 |
|
7 |
|
8 |
+
class ReportModel(MongoBaseModel):
    """A stored report document."""

    # Report text.
    report: str
    # Text kept in the ``changes`` field.
    # NOTE(review): its exact meaning is not visible here - confirm.
    changes: str
    # Name recorded for the source file.
    filename: str
    # Both timestamps default to ``datetime.now()`` evaluated when the model is
    # built, which yields a naive (timezone-less) datetime.
    datetimeInserted: datetime = Field(default_factory=datetime.now)
    datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
ocr/api/report/schemas.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
from ocr.api.report.dto import Paging, ReportModelShort
|
4 |
+
|
5 |
+
|
6 |
+
class AllReportResponse(BaseModel):
    """Payload of the "all reports" endpoint: paging info plus short reports."""

    # Paging metadata for the returned list.
    paging: Paging
    # Reports in their short form.
    data: list[ReportModelShort]
|
ocr/api/report/views.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import UploadFile, File
|
2 |
+
|
3 |
+
from ocr.api.report import report_router
|
4 |
+
from ocr.api.report.db_requests import get_all_reports_obj, delete_all_reports, get_report_obj_by_id, save_report_obj
|
5 |
+
from ocr.api.report.dto import Paging
|
6 |
+
from ocr.api.report.model import ReportModel
|
7 |
+
from ocr.api.report.schemas import AllReportResponse
|
8 |
+
from ocr.core.wrappers import OcrResponseWrapper
|
9 |
+
|
10 |
+
|
11 |
+
@report_router.get('/all')
async def get_all_reports() -> OcrResponseWrapper[AllReportResponse]:
    """Return all reports in short form as one page."""
    reports = await get_all_reports_obj()

    # No real pagination: a single page (index 0) holds every report.
    total = len(reports)
    paging = Paging(pageSize=total, pageIndex=0, totalCount=total)

    return OcrResponseWrapper(data=AllReportResponse(paging=paging, data=reports))
|
19 |
+
|
20 |
+
|
21 |
+
@report_router.delete('/all')
async def delete_all_report() -> OcrResponseWrapper:
    """Delete every stored report and answer with an empty wrapper."""
    await delete_all_reports()
    empty_response = OcrResponseWrapper()
    return empty_response
|
25 |
+
|
26 |
+
|
27 |
+
@report_router.get('/{reportId}')
async def get_report(reportId: str) -> OcrResponseWrapper[ReportModel]:
    """Return the full report with the given id.

    The lookup helper raises a 404 ``HTTPException`` when the report is missing.
    """
    found_report = await get_report_obj_by_id(reportId)
    return OcrResponseWrapper(data=found_report)
|
31 |
+
|
32 |
+
|
33 |
+
@report_router.post('')
async def create_report(
    file: UploadFile = File(...),
) -> OcrResponseWrapper[ReportModel]:
    """Store a placeholder report for an uploaded file.

    The upload is required by the signature but is not read yet; the saved
    report and changes texts are fixed placeholder strings.

    Returns:
        A wrapper around the newly saved report.
    """
    # TODO: derive the report and its changes from ``file`` instead of
    # using placeholders.
    report_text = 'Hello'
    changes_text = 'World'

    saved_report = await save_report_obj(report_text, changes_text)
    return OcrResponseWrapper(data=saved_report)
|
ocr/api/{message/utils.py → utils.py}
RENAMED
@@ -6,6 +6,20 @@ import pytesseract
|
|
6 |
from PIL import Image
|
7 |
from pdf2image import convert_from_bytes
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def divide_images(contents: bytes) -> list[bytes]:
|
11 |
images = convert_from_bytes(contents, dpi=250)
|
|
|
6 |
from PIL import Image
|
7 |
from pdf2image import convert_from_bytes
|
8 |
|
9 |
+
from ocr.api.message.models import MessageModel
|
10 |
+
|
11 |
+
|
12 |
+
def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
    """Convert stored messages into a list of role/content dictionaries.

    Args:
        messages: Messages to convert; their order is preserved.

    Returns:
        One ``{"role": ..., "content": ...}`` dict per message, where the role
        is the author's enum value and the content is the message text.
    """
    return [
        {"role": entry.author.value, "content": entry.text}
        for entry in messages
    ]
|
22 |
+
|
23 |
|
24 |
def divide_images(contents: bytes) -> list[bytes]:
|
25 |
images = convert_from_bytes(contents, dpi=250)
|