waceke committed
Commit ccc2402 · 1 Parent(s): c2d58b3

Removed the sparrow key since it wasn't needed. Changed the model and processor to senga model and processor (#1)

- Removed the sparrow key since it wasn't needed. Changed the model and processor to senga model and processor (391b6558c7e79b684f3e21c97063e16e97b5ef18)

Files changed (3)
  1. config.py +4 -6
  2. inference.py +85 -0
  3. training.py +85 -0
config.py CHANGED
@@ -1,13 +1,11 @@
  from pydantic import BaseSettings
- import os
 
 
  class Settings(BaseSettings):
-     huggingface_key: str = os.environ.get("huggingface_key")
-     sparrow_key: str = os.environ.get("sparrow_key")
-     processor: str = "katanaml-org/invoices-donut-model-v1"
-     model: str = "katanaml-org/invoices-donut-model-v1"
-     dataset: str = "katanaml-org/invoices-donut-data-v1"
+     huggingface_key: str = "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"
+     processor: str = "senga-ml/donut-training-v4"
+     model: str = "senga-ml/donut-training-v4"
+     dataset: str = "senga-ml/dnotes-data-v1"
      base_config: str = "naver-clova-ix/donut-base"
      base_processor: str = "naver-clova-ix/donut-base"
      base_model: str = "naver-clova-ix/donut-base"
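
For context, a minimal sketch of how these settings are presumably consumed by the two routers added below (both do `from config import settings`). The module-level `settings = Settings()` instance and the environment-variable override are assumptions based on standard pydantic BaseSettings behaviour, not something this commit shows:

# Assumed consumption pattern (not part of this commit): config.py presumably
# exposes a module-level instance, since the routers import `settings` directly.
from config import Settings

settings = Settings()  # field defaults above; pydantic BaseSettings also reads
                       # matching environment variables (e.g. huggingface_key) if set

print(settings.model)    # "senga-ml/donut-training-v4"
print(settings.dataset)  # "senga-ml/dnotes-data-v1"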
inference.py ADDED
@@ -0,0 +1,85 @@
+ from fastapi import APIRouter, File, UploadFile, Form
+ from typing import Optional
+ from PIL import Image
+ import urllib.request
+ from io import BytesIO
+ from config import settings
+ import utils
+ import os
+ import json
+ from routers.donut_inference import process_document_donut
+
+
+ router = APIRouter()
+
+ def count_values(obj):
+     if isinstance(obj, dict):
+         count = 0
+         for value in obj.values():
+             count += count_values(value)
+         return count
+     elif isinstance(obj, list):
+         count = 0
+         for item in obj:
+             count += count_values(item)
+         return count
+     else:
+         return 1
+
+
+ @router.post("/inference")
+ async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
+                         model_in_use: str = Form('donut')):
+
+     # if sparrow_key != settings.sparrow_key:
+     #     return {"error": "Invalid Sparrow key."}
+
+     result = []
+     if file:
+         # Ensure the uploaded file is a JPG image
+         if file.content_type not in ["image/jpeg", "image/jpg"]:
+             return {"error": "Invalid file type. Only JPG images are allowed."}
+
+         image = Image.open(BytesIO(await file.read()))
+         processing_time = 0
+         if model_in_use == 'donut':
+             result, processing_time = process_document_donut(image)
+         utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
+         print(f"Processing time inference: {processing_time:.2f} seconds")
+     elif image_url:
+         # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
+         with urllib.request.urlopen(image_url) as response:
+             content_type = response.info().get_content_type()
+             if content_type in ["image/jpeg", "image/jpg"]:
+                 image = Image.open(BytesIO(response.read()))
+             else:
+                 return {"error": "Invalid file type. Only JPG images are allowed."}
+
+         processing_time = 0
+         if model_in_use == 'donut':
+             result, processing_time = process_document_donut(image)
+         # parse file name from url
+         file_name = image_url.split("/")[-1]
+         utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
+         print(f"Processing time inference: {processing_time:.2f} seconds")
+     else:
+         result = {"info": "No input provided"}
+
+     return result
+
+
+ @router.get("/statistics")
+ async def get_statistics():
+     file_path = settings.inference_stats_file
+
+     # Check if the file exists, and read its content
+     if os.path.exists(file_path):
+         with open(file_path, 'r') as file:
+             try:
+                 content = json.load(file)
+             except json.JSONDecodeError:
+                 content = []
+     else:
+         content = []
+
+     return content
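
A hedged usage sketch for the new /inference endpoint above, assuming the router is mounted at the application root and the service is reachable on localhost:8000; the mount prefix, host, port, and file name are assumptions, not part of this diff:

# Hypothetical client call; URL, port, and router prefix are assumptions.
import requests

with open("note.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/inference",
        files={"file": ("note.jpg", f, "image/jpeg")},  # must be JPG, per the content-type check above
        data={"model_in_use": "donut"},
    )
print(resp.json())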
training.py ADDED
@@ -0,0 +1,85 @@
+ from fastapi import APIRouter, Form, BackgroundTasks
+ from config import settings
+ import os
+ import json
+ from routers.donut_evaluate import run_evaluate_donut
+ from routers.donut_training import run_training_donut
+ import utils
+
+
+ router = APIRouter()
+
+
+ def invoke_training(max_epochs, val_check_interval, warmup_steps, model_in_use):
+     # if sparrow_key != settings.sparrow_key:
+     #     return {"error": "Invalid Sparrow key."}
+
+     if model_in_use == 'donut':
+         processing_time = run_training_donut(max_epochs, val_check_interval, warmup_steps)
+         utils.log_stats(settings.training_stats_file, [processing_time, settings.model])
+         print(f"Processing time training: {processing_time:.2f} seconds")
+
+
+ @router.post("/training")
+ async def run_training(background_tasks: BackgroundTasks,
+                        max_epochs: int = Form(30),
+                        val_check_interval: float = Form(0.4),
+                        warmup_steps: int = Form(81),
+                        model_in_use: str = Form('donut')):
+
+     background_tasks.add_task(invoke_training, max_epochs, val_check_interval, warmup_steps, model_in_use)
+
+     return {"message": "Dnote Donut ML training started in the background"}
+
+
+ def invoke_evaluate(model_in_use):
+     # if sparrow_key != settings.sparrow_key:
+     #     return {"error": "Invalid Sparrow key."}
+
+     if model_in_use == 'donut':
+         scores, accuracy, processing_time = run_evaluate_donut()
+         utils.log_stats(settings.evaluate_stats_file, [processing_time, scores, accuracy, settings.model])
+         print(f"Processing time evaluate: {processing_time:.2f} seconds")
+
+
+ @router.post("/evaluate")
+ async def run_evaluate(background_tasks: BackgroundTasks,
+                        model_in_use: str = Form('donut')):
+
+     background_tasks.add_task(invoke_evaluate, model_in_use)
+
+     return {"message": "Dnote Donut ML model evaluation started in the background"}
+
+
+ @router.get("/statistics/training")
+ async def get_statistics_training():
+     file_path = settings.training_stats_file
+
+     # Check if the file exists, and read its content
+     if os.path.exists(file_path):
+         with open(file_path, 'r') as file:
+             try:
+                 content = json.load(file)
+             except json.JSONDecodeError:
+                 content = []
+     else:
+         content = []
+
+     return content
+
+
+ @router.get("/statistics/evaluate")
+ async def get_statistics_evaluate():
+     file_path = settings.evaluate_stats_file
+
+     # Check if the file exists, and read its content
+     if os.path.exists(file_path):
+         with open(file_path, 'r') as file:
+             try:
+                 content = json.load(file)
+             except json.JSONDecodeError:
+                 content = []
+     else:
+         content = []
+
+     return content
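
Similarly, a hedged sketch of how the training and statistics endpoints above might be invoked; the host, port, and router prefix are assumptions, and the form values simply repeat the defaults declared in the route signature:

# Hypothetical client calls; URL, port, and router prefix are assumptions.
import requests

# Start training in the background with the declared defaults.
resp = requests.post(
    "http://localhost:8000/training",
    data={"max_epochs": 30, "val_check_interval": 0.4, "warmup_steps": 81, "model_in_use": "donut"},
)
print(resp.json())  # {"message": "Dnote Donut ML training started in the background"}

# Later, read back the accumulated training statistics.
print(requests.get("http://localhost:8000/statistics/training").json())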