Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,7 @@ model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwri
|
|
15 |
# Function to extract text using both models
|
16 |
def extract_text(image):
|
17 |
try:
|
18 |
-
#
|
19 |
if isinstance(image, np.ndarray):
|
20 |
if len(image.shape) == 2: # Grayscale (H, W), convert to RGB
|
21 |
image = np.stack([image] * 3, axis=-1)
|
@@ -23,17 +23,17 @@ def extract_text(image):
|
|
23 |
else:
|
24 |
image = Image.open(image).convert("RGB") # Ensure RGB mode
|
25 |
|
26 |
-
#
|
27 |
-
image
|
28 |
|
29 |
# Process with the primary model
|
30 |
-
pixel_values = processor1(images=image, return_tensors="pt").pixel_values
|
31 |
generated_ids = model1.generate(pixel_values)
|
32 |
extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
33 |
|
34 |
# If output seems incorrect, use the fallback model
|
35 |
if len(extracted_text.strip()) < 2:
|
36 |
-
inputs = processor2(images=image, return_tensors="pt").pixel_values
|
37 |
generated_ids = model2.generate(inputs)
|
38 |
extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
39 |
|
@@ -51,4 +51,4 @@ iface = gr.Interface(
|
|
51 |
description="Upload a handwritten document and get the extracted text.",
|
52 |
)
|
53 |
|
54 |
-
iface.launch()
|
|
|
15 |
# Function to extract text using both models
|
16 |
def extract_text(image):
|
17 |
try:
|
18 |
+
# Convert NumPy array to PIL Image if needed
|
19 |
if isinstance(image, np.ndarray):
|
20 |
if len(image.shape) == 2: # Grayscale (H, W), convert to RGB
|
21 |
image = np.stack([image] * 3, axis=-1)
|
|
|
23 |
else:
|
24 |
image = Image.open(image).convert("RGB") # Ensure RGB mode
|
25 |
|
26 |
+
# Maintain aspect ratio while resizing
|
27 |
+
image.thumbnail((640, 640))
|
28 |
|
29 |
# Process with the primary model
|
30 |
+
pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(torch.float32)
|
31 |
generated_ids = model1.generate(pixel_values)
|
32 |
extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
33 |
|
34 |
# If output seems incorrect, use the fallback model
|
35 |
if len(extracted_text.strip()) < 2:
|
36 |
+
inputs = processor2(images=image, return_tensors="pt").pixel_values.to(torch.float32)
|
37 |
generated_ids = model2.generate(inputs)
|
38 |
extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
39 |
|
|
|
51 |
description="Upload a handwritten document and get the extracted text.",
|
52 |
)
|
53 |
|
54 |
+
iface.launch()
|