dan-durbin committed on
Commit
0b1f1a9
·
1 Parent(s): 9aa5f5e

Claude 3.5-assisted interface changes to allow switching between Markdown and OCR modes and adjusting generation parameters

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -17,11 +17,10 @@ model = AutoModelForVision2Seq.from_pretrained(
17
 
18
  processor = AutoProcessor.from_pretrained(repo)
19
 
20
- prompt = "<ocr>" # Options are '<ocr>' and '<md>'
21
-
22
 
23
  @spaces.GPU
24
- def process_image(image_path):
 
25
  image = Image.open(image_path)
26
  inputs = processor(text=prompt, images=image, return_tensors="pt")
27
 
@@ -33,17 +32,23 @@ def process_image(image_path):
33
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
34
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
35
 
36
- generated_ids = model.generate(**inputs, max_new_tokens=2048)
 
 
 
 
 
37
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
38
 
39
- return postprocess(generated_text, scale_height, scale_width, image)
40
 
41
 
42
- def postprocess(y, scale_height, scale_width, original_image):
 
43
  y = y.replace(prompt, "")
44
 
45
  if "<md>" in prompt:
46
- return y, original_image
47
 
48
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
49
  bboxs_raw = re.findall(pattern, y)
@@ -54,7 +59,6 @@ def postprocess(y, scale_height, scale_width, original_image):
54
 
55
  info = ""
56
 
57
- # Create a copy of the original image to draw on
58
  image_with_boxes = original_image.copy()
59
  draw = ImageDraw.Draw(image_with_boxes)
60
 
@@ -69,7 +73,6 @@ def postprocess(y, scale_height, scale_width, original_image):
69
  y1 = int(y1 * scale_height)
70
  info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
71
 
72
- # Draw rectangle on the image
73
  draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
74
 
75
  return image_with_boxes, info
@@ -77,11 +80,21 @@ def postprocess(y, scale_height, scale_width, original_image):
77
 
78
  iface = gr.Interface(
79
  fn=process_image,
80
- inputs=gr.Image(type="filepath"),
 
 
 
 
 
 
81
  outputs=[
82
- gr.Image(type="pil", label="Image with Bounding Boxes"),
83
- gr.Textbox(label="Extracted Text"),
84
  ],
 
 
 
 
85
  )
86
 
87
  iface.launch()
 
17
 
18
  processor = AutoProcessor.from_pretrained(repo)
19
 
 
 
20
 
21
  @spaces.GPU
22
+ def process_image(image_path, task, num_beams, max_new_tokens, temperature):
23
+ prompt = "<ocr>" if task == "OCR" else "<md>"
24
  image = Image.open(image_path)
25
  inputs = processor(text=prompt, images=image, return_tensors="pt")
26
 
 
32
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
33
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
34
 
35
+ generated_ids = model.generate(
36
+ **inputs,
37
+ num_beams=num_beams,
38
+ max_new_tokens=max_new_tokens,
39
+ temperature=temperature,
40
+ )
41
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
42
 
43
+ return postprocess(generated_text, scale_height, scale_width, image, prompt)
44
 
45
 
46
+ @spaces.GPU
47
+ def postprocess(y, scale_height, scale_width, original_image, prompt):
48
  y = y.replace(prompt, "")
49
 
50
  if "<md>" in prompt:
51
+ return original_image, y
52
 
53
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
54
  bboxs_raw = re.findall(pattern, y)
 
59
 
60
  info = ""
61
 
 
62
  image_with_boxes = original_image.copy()
63
  draw = ImageDraw.Draw(image_with_boxes)
64
 
 
73
  y1 = int(y1 * scale_height)
74
  info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
75
 
 
76
  draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
77
 
78
  return image_with_boxes, info
 
80
 
81
  iface = gr.Interface(
82
  fn=process_image,
83
+ inputs=[
84
+ gr.Image(type="filepath", label="Input Image"),
85
+ gr.Radio(["OCR", "Markdown"], label="Task", value="OCR"),
86
+ gr.Slider(1, 10, value=4, step=1, label="Number of Beams"),
87
+ gr.Slider(100, 4000, value=2048, step=100, label="Max New Tokens"),
88
+ gr.Slider(0.1, 1.0, value=1.0, step=0.1, label="Temperature"),
89
+ ],
90
  outputs=[
91
+ gr.Image(type="pil", label="Image with Bounding Boxes (OCR only)"),
92
+ gr.Textbox(label="Extracted Text / Markdown"),
93
  ],
94
+ title="Kosmos 2.5 OCR and Markdown Generator",
95
+ description="""Generate OCR results or Markdown from images using Kosmos 2.5.
96
+ Uses the Kosmos 2.5 [PR Branch] of the Transformers library for inference.
97
+ I don't know if the parameters do much of anything, but they're available for tweaking just in case.""",
98
  )
99
 
100
  iface.launch()