import gradio as gr
import PIL.Image
import torch
import spaces
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

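# Load the fine-tuned PaliGemma auto-damage checkpoint in bfloat16, on GPU when available.
# COLORS is a palette for drawing detection boxes; it is not referenced elsewhere in this script.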
model_id = "mattraj/curacel-autodamage-1"
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

###### Transformers Inference
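# @spaces.GPU asks Hugging Face Spaces (ZeroGPU) to allocate a GPU for the duration of each call.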
@spaces.GPU
def infer(
        image: PIL.Image.Image,
        text: str,
        max_new_tokens: int = 512  # default cap on generated tokens; the UI does not expose this control
) -> str:
    # Preprocess the prompt and image; .to(dtype=...) casts only floating-point tensors
    # (the pixel values) to the model's bfloat16 and leaves the token ids untouched.
    inputs = processor(text=text, images=image, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )
    # Drop the echoed prompt tokens so only the newly generated text is returned.
    generated_ids = generated_ids[:, inputs["input_ids"].shape[-1]:]
    result = processor.decode(generated_ids[0], skip_special_tokens=True)
    return result

######## Demo

INTRO_TEXT = """## Curacel Auto Damage demo

Fine-tuned from: google/paligemma-3b-pt-448
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Text Generation"):
        with gr.Column():
            image = gr.Image(type="pil")
            text_input = gr.Text(label="Input Text")

            text_output = gr.Text(label="Text Output")
            chat_btn = gr.Button("Submit")

        chat_inputs = [
            image,
            text_input
        ]
        chat_outputs = [
            text_output
        ]
        chat_btn.click(
            fn=infer,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

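        # The example uses PaliGemma's detection prompt format ("detect <label> ; <label> ; ...");
        # the fine-tuned model is expected to reply with location tokens and the matching label
        # for each damage class it finds in the image.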
        examples = [["./car-1.png", "detect Front-Windscreen-Damage ; Headlight-Damage ; Major-Rear-Bumper-Dent ; Rear-windscreen-Damage ; RunningBoard-Dent ; Sidemirror-Damage ; Signlight-Damage ; Taillight-Damage ; bonnet-dent ; doorouter-dent ; doorouter-scratch ; fender-dent ; front-bumper-dent ; front-bumper-scratch ; medium-Bodypanel-Dent ; paint-chip ; paint-trace ; pillar-dent ; quaterpanel-dent ; rear-bumper-dent ; rear-bumper-scratch ; roof-dent"]]
        gr.Markdown("")

        gr.Examples(
            examples=examples,
            inputs=chat_inputs,
        )

#########

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)