mattraj committed on
Commit
7da3754
·
verified ·
1 Parent(s): 2a63a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -11
app.py CHANGED
@@ -22,16 +22,40 @@ processor = PaliGemmaProcessor.from_pretrained(model_id)
22
  def infer(
23
  image: PIL.Image.Image,
24
  text: str,
25
- max_new_tokens: int
26
- ) -> str:
27
  inputs = processor(text=text, images=image, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
28
  with torch.no_grad():
29
  generated_ids = model.generate(
30
  **inputs,
31
- max_length=2048
32
  )
33
  result = processor.decode(generated_ids[0], skip_special_tokens=True)
34
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ######## Demo
37
 
@@ -47,15 +71,12 @@ with gr.Blocks(css="style.css") as demo:
47
  text_input = gr.Text(label="Input Text")
48
 
49
  text_output = gr.Text(label="Text Output")
 
50
  chat_btn = gr.Button()
51
 
52
- chat_inputs = [
53
- image,
54
- text_input
55
- ]
56
- chat_outputs = [
57
- text_output
58
- ]
59
  chat_btn.click(
60
  fn=infer,
61
  inputs=chat_inputs,
 
22
  def infer(
23
  image: PIL.Image.Image,
24
  text: str,
25
+ max_new_tokens: int = 2048
26
+ ) -> tuple:
27
  inputs = processor(text=text, images=image, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
28
  with torch.no_grad():
29
  generated_ids = model.generate(
30
  **inputs,
31
+ max_length=max_new_tokens
32
  )
33
  result = processor.decode(generated_ids[0], skip_special_tokens=True)
34
+
35
+ # Placeholder to extract bounding box info from the result (you should replace this with actual bounding box extraction)
36
+ bounding_boxes = extract_bounding_boxes(result)
37
+
38
+ # Draw bounding boxes on the image
39
+ annotated_image = image.copy()
40
+ draw = ImageDraw.Draw(annotated_image)
41
+
42
+ # Example of drawing bounding boxes (replace with actual coordinates)
43
+ for idx, (box, label) in enumerate(bounding_boxes):
44
+ color = COLORS[idx % len(COLORS)]
45
+ draw.rectangle(box, outline=color, width=3)
46
+ draw.text((box[0], box[1]), label, fill=color)
47
+
48
+ return result, annotated_image
49
+
50
+ def extract_bounding_boxes(result):
51
+ """
52
+ Extract bounding boxes and labels from the model result.
53
+ Placeholder logic - replace this with actual parsing logic from model output.
54
+
55
+ Example return: [((x1, y1, x2, y2), "Label")]
56
+ """
57
+ # Example static bounding box and label
58
+ return [((50, 50, 200, 200), "Damage"), ((300, 300, 400, 400), "Dent")]
59
 
60
  ######## Demo
61
 
 
71
  text_input = gr.Text(label="Input Text")
72
 
73
  text_output = gr.Text(label="Text Output")
74
+ output_image = gr.Image(label="Annotated Image")
75
  chat_btn = gr.Button()
76
 
77
+ chat_inputs = [image, text_input]
78
+ chat_outputs = [text_output, output_image]
79
+
 
 
 
 
80
  chat_btn.click(
81
  fn=infer,
82
  inputs=chat_inputs,