Update app.py
app.py
CHANGED
@@ -22,16 +22,40 @@ processor = PaliGemmaProcessor.from_pretrained(model_id)
 def infer(
     image: PIL.Image.Image,
     text: str,
-    max_new_tokens: int
-) ->
+    max_new_tokens: int = 2048
+) -> tuple:
     inputs = processor(text=text, images=image, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
     with torch.no_grad():
         generated_ids = model.generate(
             **inputs,
-            max_length=
+            max_length=max_new_tokens
         )
     result = processor.decode(generated_ids[0], skip_special_tokens=True)
-    return result
+
+    # Placeholder to extract bounding box info from the result (you should replace this with actual bounding box extraction)
+    bounding_boxes = extract_bounding_boxes(result)
+
+    # Draw bounding boxes on the image
+    annotated_image = image.copy()
+    draw = ImageDraw.Draw(annotated_image)
+
+    # Example of drawing bounding boxes (replace with actual coordinates)
+    for idx, (box, label) in enumerate(bounding_boxes):
+        color = COLORS[idx % len(COLORS)]
+        draw.rectangle(box, outline=color, width=3)
+        draw.text((box[0], box[1]), label, fill=color)
+
+    return result, annotated_image
+
+def extract_bounding_boxes(result):
+    """
+    Extract bounding boxes and labels from the model result.
+    Placeholder logic - replace this with actual parsing logic from model output.
+
+    Example return: [((x1, y1, x2, y2), "Label")]
+    """
+    # Example static bounding box and label
+    return [((50, 50, 200, 200), "Damage"), ((300, 300, 400, 400), "Dent")]
 
 ######## Demo
 
@@ -47,15 +71,12 @@ with gr.Blocks(css="style.css") as demo:
     text_input = gr.Text(label="Input Text")
 
     text_output = gr.Text(label="Text Output")
+    output_image = gr.Image(label="Annotated Image")
     chat_btn = gr.Button()
 
-    chat_inputs = [
-
-
-    ]
-    chat_outputs = [
-        text_output
-    ]
+    chat_inputs = [image, text_input]
+    chat_outputs = [text_output, output_image]
+
     chat_btn.click(
         fn=infer,
         inputs=chat_inputs,
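A note on the generation call: the parameter is named max_new_tokens but is forwarded to generate() as max_length. In transformers, max_length caps the whole sequence including the prompt tokens, while max_new_tokens caps only the generated continuation, so a long prompt silently eats into the output budget. generate() accepts max_new_tokens directly; a minimal fix would be:

    # max_new_tokens bounds only the generated tokens, independent of prompt length.
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)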
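The drawing code also depends on two names the diff never introduces: ImageDraw and COLORS. Neither appears in the hunk context, so unless they are already defined earlier in app.py, the Space will raise a NameError on the first click. A minimal sketch of the missing pieces (the palette values are illustrative, not from the commit):

    from PIL import ImageDraw  # used by infer() to draw on the annotated image

    # Any list of PIL-compatible color names works; these are placeholders.
    COLORS = ["red", "lime", "blue", "yellow", "magenta", "cyan"]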
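extract_bounding_boxes itself is explicitly a placeholder that returns two static boxes. For PaliGemma, detection prompts (e.g. "detect damage") make the model emit four <locDDDD> tokens per box, encoding y_min, x_min, y_max, x_max on a 0-1023 grid relative to the image, followed by the label. Below is a sketch of a parser under that assumption; the regex and the width/height parameters are mine, not part of the commit, so infer() would have to pass the image size along with the decoded text:

    import re

    # Four <loc####> tokens per detection, then the label text.
    _LOC_RE = re.compile(r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^<;]+)")

    def extract_bounding_boxes(result: str, width: int, height: int):
        """Parse PaliGemma detection output into [((x1, y1, x2, y2), label), ...].

        Assumes the <loc####> convention: each value lies in [0, 1023] and the
        order is y_min, x_min, y_max, x_max relative to the image dimensions.
        """
        boxes = []
        for y1, x1, y2, x2, label in _LOC_RE.findall(result):
            box = (
                int(x1) / 1024 * width,   # x1
                int(y1) / 1024 * height,  # y1
                int(x2) / 1024 * width,   # x2
                int(y2) / 1024 * height,  # y2
            )
            boxes.append((box, label.strip()))
        return boxes

With that in place, the call in infer() becomes extract_bounding_boxes(result, *image.size), and draw.rectangle(box, ...) receives real pixel coordinates instead of the hard-coded examples.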