sooooner committed on
Commit 68c034b · 1 Parent(s): dcbff67
Files changed (2)
  1. app.py +4 -9
  2. utils.py +18 -24
app.py CHANGED
@@ -6,14 +6,11 @@ import spaces
 
 from utils import Image2Text
 
-@spaces.GPU(duration=15)
+@spaces.GPU(duration=30)
 def greet(input_img):
     global image_to_text
-    print('-----------')
-    print(input_img[0])
-    print('-----------')
-    contents = image_to_text.get_text(input_img[0], num_beams=4)
-    return '\n'.join(contents)
+    contents = image_to_text.get_text(input_img, num_beams=4)
+    return contents
 
 examples_path = os.path.dirname(__file__)
 
@@ -26,15 +23,13 @@ if __name__ == "__main__":
 
     demo = gr.Interface(
         fn=greet,
-        # inputs="image",
         inputs=gr.File(
             label="Drag (Select) 1 or more photos of your face",
             file_types=["image"],
             file_count="multiple"
         ),
-        outputs="text",
+        outputs=gr.JSON(label="Extracted Texts"),
         title=f"🍩 for Hwp math problems",
-        # examples=[os.path.join(examples_path, "samples", img_name) for img_name in sorted(os.listdir("samples"))],
         cache_examples=True
     )
 
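For reference, app.py after this commit presumably assembles to roughly the sketch below; everything outside the two hunks (the os/spaces/gradio imports, the Image2Text setup, and the final launch call) is not shown in the diff and is assumed here.

import os
import spaces
import gradio as gr

from utils import Image2Text

@spaces.GPU(duration=30)
def greet(input_img):
    global image_to_text
    # gr.File(file_count="multiple") hands greet the list of uploaded files,
    # so the whole list is forwarded to get_text and a list comes back for gr.JSON
    contents = image_to_text.get_text(input_img, num_beams=4)
    return contents

examples_path = os.path.dirname(__file__)

if __name__ == "__main__":
    image_to_text = Image2Text()  # assumed setup; the real constructor arguments are not in the diff
    demo = gr.Interface(
        fn=greet,
        inputs=gr.File(
            label="Drag (Select) 1 or more photos of your face",
            file_types=["image"],
            file_count="multiple"
        ),
        outputs=gr.JSON(label="Extracted Texts"),
        title=f"🍩 for Hwp math problems",
        cache_examples=True
    )
    demo.launch()  # assumed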
 
utils.py CHANGED
@@ -1,10 +1,13 @@
 import os
-from typing import Union
-
+import re
+import torch
+import requests
+import numpy as np
 import PIL.Image
 import PIL.ImageOps
-import requests
-
+from PIL import Image
+from typing import Union
+from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
 
 def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     """
@@ -36,15 +39,6 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     image = image.convert("RGB")
     return image
 
-
-
-
-import re
-import torch
-import numpy as np
-from PIL import Image
-from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
-
 def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
     width, height = image.size
     width_ratio = width / target_width
@@ -93,17 +87,17 @@ class Image2Text:
         return model, processor
 
     def load_img(self, inputs, width=480, height=480):
-        # image = Image.fromarray(inputs)
-        image = load_image(inputs)
-        image = aspect_ratio_preserving_resize_and_crop(image, target_width=width, target_height=height).convert("RGB")
+        images = [load_image(input_) for input_ in inputs]
+        images = [aspect_ratio_preserving_resize_and_crop(image, target_width=width, target_height=height) for image in images]
         img = self.processor(image , return_tensors="pt", size=(width, height)).pixel_values
-        pixel_values = img.to(self.device)
+        imgs = self.processor([image.convert("RGB") for image in images], return_tensors="pt", size=(width, height)).pixel_values
+        pixel_values = imgs.to(self.device)
         return pixel_values
 
     def generate(self, pixel_values, num_beams):
         outputs = self.model.generate(
             pixel_values,
-            decoder_input_ids=self.decoder_input_ids,
+            decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
             max_length=2048,
             early_stopping=True,
             pad_token_id=self.processor.tokenizer.pad_token_id,
@@ -116,12 +110,12 @@ class Image2Text:
         return outputs
 
     def postprocessing(self, outputs):
-        seq = self.processor.batch_decode(outputs.sequences)[0]
-        seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
-        seq = re.sub(r"<.*?>", "", seq, count=1).strip()
-        seq = self.processor.token2json(seq)
-        contents = seq['content'].split('[newline]')
-        return contents
+        seqs = self.processor.batch_decode(outputs.sequences)
+        seqs = [seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "") for seq in seqs]
+        seqs = [re.sub(r"<.*?>", "", seq, count=1).strip() for seq in seqs]
+        seqs = [self.processor.token2json(seq) for seq in seqs]
+        contents = [seq['content'].split('[newline]') for seq in seqs]
+        return ['\n'.join(content) for content in contents]
 
     def get_text(self, img_path, num_beams=4):
         pixel_values = self.load_img(img_path)
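The batched generate call above tiles the single decoder prompt across the batch with .repeat, so every image gets the same task prefix. A minimal sketch of that tiling with hypothetical token ids (the real prompt ids come from the Donut processor and are not part of this diff):

import torch

decoder_input_ids = torch.tensor([[100, 101, 102]])   # hypothetical (1, prompt_len) task prompt
batch_size = 4                                        # e.g. pixel_values.shape[0] for four uploads

# repeat(batch_size, 1) copies the prompt row-wise, one identical prefix per image
batched_prompt = decoder_input_ids.repeat(batch_size, 1)
print(batched_prompt.shape)  # torch.Size([4, 3])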
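Taken together, get_text now takes a list of images and returns one newline-joined string per image, which the gr.JSON output renders as a list. A hypothetical usage sketch (the Image2Text constructor and checkpoint are assumed; it also assumes the leftover single-image line "img = self.processor(image , ...)" in load_img is removed, since image is no longer defined at that point):

from utils import Image2Text

image_to_text = Image2Text()  # assumed constructor; real arguments are not shown in this commit
texts = image_to_text.get_text(
    ["problem_1.png", "problem_2.png"],  # load_img now expects a list of paths or PIL images
    num_beams=4,
)
# one string per input image, e.g.
# ["x + 1 = 2\nx = 1", "2x = 4\nx = 2"]
print(texts)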