brestok commited on
Commit
ed5ec6a
·
1 Parent(s): 994916e

add features

Browse files
ocr/__init__.py CHANGED
@@ -12,8 +12,11 @@ from ocr.core.wrappers import OcrResponseWrapper, ErrorOcrResponse
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
15
- from ocr.api.message import report_router
16
- app.include_router(report_router, tags=['message'])
 
 
 
17
 
18
  app.add_middleware(
19
  CORSMiddleware,
@@ -41,6 +44,6 @@ def create_app() -> FastAPI:
41
 
42
  @app.get("/")
43
  async def read_root():
44
- return {"message": "Hello world!"}
45
 
46
  return app
 
12
  def create_app() -> FastAPI:
13
  app = FastAPI()
14
 
15
+ from ocr.api.report import report_router
16
+ app.include_router(report_router, tags=['report'])
17
+
18
+ from ocr.api.message import message_router
19
+ app.include_router(message_router, tags=['message'])
20
 
21
  app.add_middleware(
22
  CORSMiddleware,
 
44
 
45
  @app.get("/")
46
  async def read_root():
47
+ return {"report": "Hello world!"}
48
 
49
  return app
ocr/api/message/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi.routing import APIRouter
2
 
3
- report_router = APIRouter(
4
- prefix="/api/report", tags=["message"]
5
  )
6
 
7
  from . import views
 
1
  from fastapi.routing import APIRouter
2
 
3
+ message_router = APIRouter(
4
+ prefix="/api/message", tags=["message"]
5
  )
6
 
7
  from . import views
ocr/api/message/ai/openai_request.py DELETED
@@ -1,17 +0,0 @@
1
- from ocr.api.message.ai.prompts import OCRPrompts
2
- from ocr.core.wrappers import openai_wrapper
3
-
4
-
5
- @openai_wrapper(model='gpt-4o-mini')
6
- async def generate_report(request_content: list[dict]):
7
- messages = [
8
- {
9
- "role": "system",
10
- "content": OCRPrompts.generate_general_answer
11
- },
12
- {
13
- "role": "user",
14
- "content": request_content
15
- }
16
- ]
17
- return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ocr/api/message/ai/prompts.py DELETED
@@ -1,44 +0,0 @@
1
- class OCRPrompts:
2
- generate_general_answer = """## Task
3
-
4
- You must analyze the text extracted from medical document and generate a comprehensive report in **Markdown2** format. Ensure that every detail provided in the document is included, and do not omit or modify any information. Your output must strictly follow the required format.
5
-
6
- ## Report Structure
7
-
8
- The report should be structured as follows, with each section containing only relevant information from the document:
9
-
10
- ```markdown
11
- ## Patient Information
12
-
13
- - Name: [Patient Name]
14
- - Age: [Patient Age]
15
- - Date of Scan: [Date]
16
- - Indication: [Reason for the CT scan]
17
-
18
- ## Findings
19
-
20
- **Primary findings**:
21
- [Describe significant abnormalities or findings relevant to the indication]
22
-
23
- ** Secondary findings**:
24
- [List incidental findings, e.g., "Mild hepatic steatosis noted."]
25
- **No abnormalities**:
26
- [Mention organs or systems without abnormalities, e.g., "No evidence of lymphadenopathy or pleural effusion."]
27
-
28
- ## Impression
29
-
30
- [Summarize the findings concisely, e.g., "Findings suggest a primary lung tumor. Biopsy recommended for further evaluation."]
31
-
32
- ## Recommendations
33
-
34
- [Include next steps or further tests, e.g., "PET scan and consultation with oncology recommended."]
35
- ```
36
-
37
- [INST]
38
-
39
- ## Instructions
40
-
41
- - **Do not invent or infer any information.** Only use data provided in the user request.
42
- - Ensure that the format is followed strictly, and the output is complete without any deviations.
43
-
44
- [/INST]"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ocr/api/message/db_requests.py CHANGED
@@ -1,26 +1,39 @@
1
  import asyncio
2
 
3
- from ocr.api.message.model import MessageModel
4
- from ocr.core.config import settings
5
 
 
 
 
 
6
 
7
- async def get_all_chat_messages_obj(page_size: int, page_index: int) -> tuple[list[MessageModel], int]:
8
- skip = page_size * page_index
9
- objects, total_count = await asyncio.gather(
10
- settings.DB_CLIENT.messages
11
- .find()
12
- .skip(skip)
13
- .limit(page_size)
14
- .to_list(length=page_size),
15
- settings.DB_CLIENT.messages.count_documents({})
16
- )
17
- return objects, total_count
18
 
 
 
 
 
 
 
19
 
20
- async def save_report_obj(report: str, filename: str) -> MessageModel:
21
- message = MessageModel(
22
- text=report,
23
- filename=filename,
24
- )
25
  await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
26
  return message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
 
3
+ from fastapi import HTTPException
 
4
 
5
+ from ocr.api.message.dto import Author
6
+ from ocr.api.message.models import MessageModel
7
+ from ocr.api.message.schemas import CreateMessageRequest
8
+ from ocr.core.config import settings
9
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ async def create_message_obj(
12
+ report_id: str, message_data: CreateMessageRequest
13
+ ) -> MessageModel:
14
+ report = await settings.DB_CLIENT.reports.find_one({"id": report_id})
15
+ if not report:
16
+ raise HTTPException(status_code=404, detail="Report not found.")
17
 
18
+ message = MessageModel(**message_data.model_dump(), reportId=report_id, author=Author.User)
 
 
 
 
19
  await settings.DB_CLIENT.messages.insert_one(message.to_mongo())
20
  return message
21
+
22
+
23
+ async def get_all_chat_messages_obj(report_id: str) -> list[MessageModel]:
24
+ messages, report = await asyncio.gather(
25
+ settings.DB_CLIENT.messages.find({"reportId": report_id}).to_list(length=None),
26
+ settings.DB_CLIENT.reports.find_one({"id": report_id})
27
+ )
28
+ messages = [MessageModel.from_mongo(message) for message in messages]
29
+ if not report:
30
+ raise HTTPException(status_code=404, detail="Report not found")
31
+ return messages
32
+
33
+
34
+ async def save_assistant_user_message(user_message: str, assistant_message: str, report_id: str) -> MessageModel:
35
+ user_message = MessageModel(reportId=report_id, author=Author.User, text=user_message)
36
+ assistant_message = MessageModel(reportId=report_id, author=Author.Assistant, text=assistant_message)
37
+ await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
38
+ await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
39
+ return user_message
ocr/api/message/dto.py CHANGED
@@ -1,7 +1,8 @@
 
 
1
  from pydantic import BaseModel
2
 
3
 
4
- class Paging(BaseModel):
5
- pageSize: int
6
- pageIndex: int
7
- totalCount: int
 
1
+ from enum import Enum
2
+
3
  from pydantic import BaseModel
4
 
5
 
6
+ class Author(Enum):
7
+ User = "user"
8
+ Assistant = "assistant"
 
ocr/api/message/models.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from pydantic import Field
4
+
5
+ from ocr.api.message.dto import Author
6
+ from ocr.core.database import MongoBaseModel
7
+
8
+
9
+ class MessageModel(MongoBaseModel):
10
+ reportId: str
11
+ author: Author
12
+ text: str
13
+ datetimeInserted: datetime = Field(default_factory=datetime.now)
14
+ datetimeUpdated: datetime = Field(default_factory=datetime.now)
ocr/api/message/schemas.py CHANGED
@@ -1,14 +1,13 @@
1
  from pydantic import BaseModel
2
 
3
- from ocr.api.message.dto import Paging
4
- from ocr.api.message.model import MessageModel
5
  from ocr.core.wrappers import OcrResponseWrapper
6
 
7
 
8
  class CreateMessageRequest(BaseModel):
9
  text: str
10
 
11
-
12
  class MessageWrapper(OcrResponseWrapper[MessageModel]):
13
  pass
14
 
 
1
  from pydantic import BaseModel
2
 
3
+ from ocr.api.message.models import MessageModel
4
+ from ocr.api.report.dto import Paging
5
  from ocr.core.wrappers import OcrResponseWrapper
6
 
7
 
8
  class CreateMessageRequest(BaseModel):
9
  text: str
10
 
 
11
  class MessageWrapper(OcrResponseWrapper[MessageModel]):
12
  pass
13
 
ocr/api/message/views.py CHANGED
@@ -1,24 +1,17 @@
1
- from typing import Optional
2
-
3
- from fastapi import Query, UploadFile, File
4
-
5
- from ocr.api.message import report_router
6
- from ocr.api.message.ai.openai_request import generate_report
7
- from ocr.api.message.db_requests import get_all_chat_messages_obj, save_report_obj
8
- from ocr.api.message.dto import Paging
9
- from ocr.api.message.model import MessageModel
10
- from ocr.api.message.schemas import (AllMessageWrapper,
11
- AllMessageResponse)
12
- from ocr.api.message.utils import divide_images, prepare_request_content, clean_response
13
  from ocr.core.wrappers import OcrResponseWrapper
14
 
15
 
16
- @report_router.get('/all')
17
  async def get_all_chat_messages(
18
- pageSize: Optional[int] = Query(10, description="Number of countries to return per page"),
19
- pageIndex: Optional[int] = Query(0, description="Page index to retrieve"),
20
  ) -> AllMessageWrapper:
21
- messages, _ = await get_all_chat_messages_obj(pageSize, pageIndex)
22
  response = AllMessageResponse(
23
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
24
  data=messages
@@ -26,16 +19,14 @@ async def get_all_chat_messages(
26
  return AllMessageWrapper(data=response)
27
 
28
 
29
- @report_router.post('/generate')
30
  async def create_message(
31
- file: UploadFile = File(...),
 
32
  ) -> OcrResponseWrapper[MessageModel]:
33
- try:
34
- contents = await file.read()
35
- images = divide_images(contents)
36
- content = prepare_request_content(images)
37
- report = await generate_report(content)
38
- response = await save_report_obj(clean_response(report), file.filename)
39
- return OcrResponseWrapper(data=response)
40
- finally:
41
- await file.close()
 
1
+ from ocr.api.message import message_router
2
+ from ocr.api.message.db_requests import get_all_chat_messages_obj, save_assistant_user_message
3
+ from ocr.api.message.models import MessageModel
4
+ from ocr.api.message.schemas import AllMessageWrapper, AllMessageResponse, CreateMessageRequest
5
+ from ocr.api.report.dto import Paging
6
+ from ocr.api.utils import transform_messages_to_openai
 
 
 
 
 
 
7
  from ocr.core.wrappers import OcrResponseWrapper
8
 
9
 
10
+ @message_router.get('/{reportId}/all')
11
  async def get_all_chat_messages(
12
+ reportId: str
 
13
  ) -> AllMessageWrapper:
14
+ messages = await get_all_chat_messages_obj(reportId)
15
  response = AllMessageResponse(
16
  paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=len(messages)),
17
  data=messages
 
19
  return AllMessageWrapper(data=response)
20
 
21
 
22
+ @message_router.post('/{reportId}')
23
  async def create_message(
24
+ reportId: str,
25
+ message_data: CreateMessageRequest,
26
  ) -> OcrResponseWrapper[MessageModel]:
27
+ messages = await get_all_chat_messages_obj(reportId)
28
+ message_history = transform_messages_to_openai(messages)
29
+ # response = await generate_response()
30
+ response = 'Hello world'
31
+ response = await save_assistant_user_message(message_data.text, response, reportId)
32
+ return OcrResponseWrapper(data=response)
 
 
 
ocr/api/openai_requests.py ADDED
File without changes
ocr/api/prompts.py ADDED
File without changes
ocr/api/report/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from fastapi.routing import APIRouter
2
+
3
+ report_router = APIRouter(
4
+ prefix="/api/report", tags=["report"]
5
+ )
6
+
7
+ from . import views
ocr/api/report/db_requests.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException
2
+
3
+ from ocr.api.report.dto import ReportModelShort
4
+ from ocr.api.report.model import ReportModel
5
+ from ocr.core.config import settings
6
+
7
+
8
+ async def get_all_reports_obj() -> list[ReportModelShort]:
9
+ reports = await settings.DB_CLIENT.reports.find({}).to_list(length=None)
10
+ return [ReportModelShort(**report) for report in reports]
11
+
12
+
13
+ async def delete_all_reports() -> None:
14
+ await settings.DB_CLIENT.reports.delete_many({})
15
+
16
+
17
+ async def get_report_obj_by_id(report_id: str) -> ReportModel:
18
+ report = await settings.DB_CLIENT.reports.find_one({"id": report_id})
19
+ if not report:
20
+ raise HTTPException(status_code=404, detail="Report not found")
21
+ return ReportModel.from_mongo(report)
22
+
23
+
24
+ async def save_report_obj(report: str, changes: str) -> ReportModel:
25
+ report = ReportModel(report=report, changes=changes, filename='maksim.docx')
26
+ await settings.DB_CLIENT.reports.insert_one(report.to_mongo())
27
+ return report
ocr/api/report/dto.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import ClassVar
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from ocr.api.report.model import ReportModel
6
+
7
+
8
+ class Paging(BaseModel):
9
+ pageSize: int
10
+ pageIndex: int
11
+ totalCount: int
12
+
13
+
14
+ class ReportModelShort(ReportModel):
15
+ report: ClassVar[str]
16
+ changes: ClassVar[str]
ocr/api/{message → report}/model.py RENAMED
@@ -5,8 +5,9 @@ from pydantic import Field
5
  from ocr.core.database import MongoBaseModel
6
 
7
 
8
- class MessageModel(MongoBaseModel):
9
- text: str
 
10
  filename: str
11
  datetimeInserted: datetime = Field(default_factory=datetime.now)
12
- datetimeUpdated: datetime = Field(default_factory=datetime.now)
 
5
  from ocr.core.database import MongoBaseModel
6
 
7
 
8
+ class ReportModel(MongoBaseModel):
9
+ report: str
10
+ changes: str
11
  filename: str
12
  datetimeInserted: datetime = Field(default_factory=datetime.now)
13
+ datetimeUpdated: datetime = Field(default_factory=datetime.now)
ocr/api/report/schemas.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ from ocr.api.report.dto import Paging, ReportModelShort
4
+
5
+
6
+ class AllReportResponse(BaseModel):
7
+ paging: Paging
8
+ data: list[ReportModelShort]
ocr/api/report/views.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import UploadFile, File
2
+
3
+ from ocr.api.report import report_router
4
+ from ocr.api.report.db_requests import get_all_reports_obj, delete_all_reports, get_report_obj_by_id, save_report_obj
5
+ from ocr.api.report.dto import Paging
6
+ from ocr.api.report.model import ReportModel
7
+ from ocr.api.report.schemas import AllReportResponse
8
+ from ocr.core.wrappers import OcrResponseWrapper
9
+
10
+
11
+ @report_router.get('/all')
12
+ async def get_all_reports() -> OcrResponseWrapper[AllReportResponse]:
13
+ reports = await get_all_reports_obj()
14
+ response = AllReportResponse(
15
+ paging=Paging(pageSize=len(reports), pageIndex=0, totalCount=len(reports)),
16
+ data=reports
17
+ )
18
+ return OcrResponseWrapper(data=response)
19
+
20
+
21
+ @report_router.delete('/all')
22
+ async def delete_all_report() -> OcrResponseWrapper:
23
+ await delete_all_reports()
24
+ return OcrResponseWrapper()
25
+
26
+
27
+ @report_router.get('/{reportId}')
28
+ async def get_report(reportId: str) -> OcrResponseWrapper[ReportModel]:
29
+ report = await get_report_obj_by_id(reportId)
30
+ return OcrResponseWrapper(data=report)
31
+
32
+
33
+ @report_router.post('')
34
+ async def create_report(
35
+ file: UploadFile = File(...),
36
+ ) -> OcrResponseWrapper[ReportModel]:
37
+ # messages = await create_new_reports(reportId)
38
+ # response = await generate_response(message_data.text, message_history)
39
+ report, changes = 'Hello', 'World'
40
+ report = await save_report_obj(report, changes)
41
+ return OcrResponseWrapper(data=report)
ocr/api/{message/utils.py → utils.py} RENAMED
@@ -6,6 +6,20 @@ import pytesseract
6
  from PIL import Image
7
  from pdf2image import convert_from_bytes
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def divide_images(contents: bytes) -> list[bytes]:
11
  images = convert_from_bytes(contents, dpi=250)
 
6
  from PIL import Image
7
  from pdf2image import convert_from_bytes
8
 
9
+ from ocr.api.message.models import MessageModel
10
+
11
+
12
+ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
13
+ openai_messages = []
14
+ for message in messages:
15
+ content = message.text
16
+ openai_messages.append({
17
+ "role": message.author.value,
18
+ "content": content
19
+ })
20
+
21
+ return openai_messages
22
+
23
 
24
  def divide_images(contents: bytes) -> list[bytes]:
25
  images = convert_from_bytes(contents, dpi=250)