DeepDiveDev committed on
Commit
3a8de33
·
verified ·
1 Parent(s): 9b98135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
- import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
4
  import requests
5
 
6
  # Load your model from Hugging Face
@@ -9,7 +10,11 @@ model = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-oc
9
 
10
  # Function to extract text
11
  def extract_text(image):
12
- image = Image.open(image).convert("RGB")
 
 
 
 
13
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
14
  generated_ids = model.generate(pixel_values)
15
  extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
1
+ import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import numpy as np
5
  import requests
6
 
7
  # Load your model from Hugging Face
 
10
 
11
  # Function to extract text
12
  def extract_text(image):
13
+ if isinstance(image, np.ndarray): # Check if input is a NumPy array
14
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
15
+ else:
16
+ image = Image.open(image).convert("RGB") # Open normally if not a NumPy array
17
+
18
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
19
  generated_ids = model.generate(pixel_values)
20
  extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]