import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
from PIL import Image

# Hugging Face Hub id of the fine-tuned PaliGemma transcription model.
model_id = "mattraj/curacel-transcription-1"
# NOTE(review): COLORS appears unused in this file — confirm against other modules before removing.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
# Prefer GPU when available; weights are loaded in bfloat16 and the model is
# put in eval mode at import time (this download/load is a module-level side effect).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

def resize_and_pad(image, target_dim):
    """Fit *image* inside a target_dim x target_dim square and pad with white.

    The image is scaled (preserving aspect ratio) so its longer side equals
    ``target_dim``, then centered on a white RGB canvas.

    Args:
        image: Input ``PIL.Image.Image``.
        target_dim: Side length in pixels of the square output.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(target_dim, target_dim)``.
    """
    # (Removed a dead `scale_factor = 1` multiplier from the original.)
    aspect_ratio = image.width / image.height
    if aspect_ratio > 1:
        # Landscape: width fills the square, height shrinks proportionally.
        new_width = target_dim
        new_height = int(target_dim / aspect_ratio)
    else:
        # Portrait (or square): height fills the square.
        new_height = target_dim
        new_width = int(target_dim * aspect_ratio)

    resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    # Center the resized image on a white square canvas.
    new_image = Image.new("RGB", (target_dim, target_dim), (255, 255, 255))
    new_image.paste(resized_image, ((target_dim - new_width) // 2, (target_dim - new_height) // 2))

    return new_image


###### Transformers Inference
@spaces.GPU
def infer(
        image: PIL.Image.Image,
        text: str,
        max_new_tokens: int = 512,
) -> str:
    """Transcribe an image with the PaliGemma model.

    Args:
        image: Input image; resized/padded to 448x448 before processing.
        text: Prompt text passed to the processor alongside the image.
        max_new_tokens: Cap on newly generated tokens. Defaults to 512 so the
            Gradio click handler — which supplies only ``image`` and ``text``
            — can call this function without a third input (the original
            required argument made every button click raise TypeError).

    Returns:
        The decoded model output with special tokens stripped.
    """
    inputs = processor(
        text=text,
        images=resize_and_pad(image, 448),
        return_tensors="pt",
        padding="longest",
        do_convert_rgb=True,
    ).to(device).to(dtype=model.dtype)
    with torch.no_grad():
        # Bug fix: the original ignored max_new_tokens and hard-coded
        # max_length=2048; honor the parameter instead.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
        )
    result = processor.decode(generated_ids[0], skip_special_tokens=True)
    return result

######## Demo

# Markdown shown at the top of the Gradio demo. Typos fixed from the original
# ("necessariily" -> "necessarily", "mispellings" -> "misspellings").
INTRO_TEXT = """## Curacel Handwritten Arabic demo\n\n
Finetuned from: google/paligemma-3b-pt-448


Translation model demo at: https://prod.arabic-gpt.ai/

Prompts:
Translate the Arabic to English: {model output}

The following is a diagnosis in Arabic from a medical billing form we need to translate to English. The transcriber is not necessarily accurate so one or more characters or words may be wrong. Given what is written, what is the most likely diagnosis. Think step by step, and think about similar words or misspellings in Arabic. Give multiple Arabic diagnoses along with the translation in English for each, then finally select the diagnosis that makes the most sense given what was transcribed and print the English translation as your most likely final translation. Transcribed text:  {model output}
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Text Generation"):
        # Input widgets: an image upload plus a free-text prompt, followed by
        # the output textbox and the trigger button.
        with gr.Column():
            image = gr.Image(type="pil")
            text_input = gr.Text(label="Input Text")

            text_output = gr.Text(label="Text Output")
            chat_btn = gr.Button()

        # Wire the button click to the inference function.
        chat_btn.click(
            fn=infer,
            inputs=[image, text_input],
            outputs=[text_output],
        )

        # Clickable example rows: (image path, prompt) pairs that populate
        # the inputs above. All examples share the same prompt.
        transcribe_prompt = "Transcribe the Arabic text."
        gr.Markdown("")
        gr.Examples(
            examples=[
                ["./diagnosis-1.png", transcribe_prompt],
                ["./4800-13-diagnosis.png", transcribe_prompt],
                ["./sign.png", transcribe_prompt],
            ],
            inputs=[image, text_input],
        )

#########

if __name__ == "__main__":
    # Queue requests (at most 10 pending) and start the server with debug logging.
    app = demo.queue(max_size=10)
    app.launch(debug=True)