File size: 1,720 Bytes
415bf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaf214f
415bf3c
 
 
 
a17b121
415bf3c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from config import Settings
import torch
from PIL import Image
import io
from contextlib import asynccontextmanager
from transformers import VisionEncoderDecoderModel
from fastapi import FastAPI, UploadFile, Form, HTTPException
from transformers import TrOCRProcessor, AutoTokenizer, ViTImageProcessor

# Module-level registry for app-wide resources (settings, torch device,
# TrOCR processor, OCR model). Populated by the FastAPI lifespan handler
# at startup and cleared at shutdown.
config = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load OCR resources once at startup and release them at shutdown.

    Populates the module-level ``config`` dict with the application
    settings, the selected torch device, the TrOCR processor (built from
    the configured tokenizer + ViT image processor), and the OCR model.
    """
    settings = Settings()
    config['settings'] = settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['device'] = device
    tokenizer = AutoTokenizer.from_pretrained(settings.TOKENIZER)
    feature_extractor = ViTImageProcessor.from_pretrained(settings.FEATURE_EXTRACTOR)
    config['processor'] = TrOCRProcessor(image_processor=feature_extractor, tokenizer=tokenizer)
    # Fix: the device was computed but never used, so on a CUDA machine the
    # model silently stayed on CPU. Move it to the selected device and switch
    # to eval mode (disables dropout etc.) since it is only used for inference.
    config['ocr_model'] = VisionEncoderDecoderModel.from_pretrained(settings.OCR_MODEL).to(device).eval()

    yield
    # Clean up and release the resources on shutdown.
    config.clear()

# Wire the lifespan handler so model loading happens once at startup
# rather than per-request.
app = FastAPI(lifespan=lifespan)

@app.get("/")
def api_home():
    """Health/welcome endpoint for the service root."""
    welcome = {'detail': 'Welcome to Sinhala OCR Space'}
    return welcome

@app.post("/apply-trocr")
async def ApplyOCR(file: UploadFile):
    """Run TrOCR on an uploaded image and return the recognized text.

    Args:
        file: The uploaded image file (any format PIL can open).

    Returns:
        ``{"ocr_result": <text>}`` with the decoded OCR output.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
                       (Model/inference failures propagate as 500.)
    """
    # Keep the try body minimal: only the upload parsing can be blamed on
    # the client. Previously any failure returned HTTP 200 with an
    # {"error": ...} dict, masking problems from callers — HTTPException
    # was imported but never used.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")

    processor = config['processor']
    model = config['ocr_model']
    # Fix: inputs must live on the same device as the model; previously the
    # tensor stayed on CPU and would crash against a CUDA-placed model.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device)
    # Inference only — skip autograd graph construction to save memory.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return {"ocr_result": generated_text}