DeepDiveDev committed
Commit 2653a83 · verified · 1 Parent(s): d1bb7e2

Update app.py

Files changed (1)
app.py +6 -6
app.py CHANGED
@@ -15,7 +15,7 @@ model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwri
 # Function to extract text using both models
 def extract_text(image):
     try:
-        # Ensure the input is a PIL image
+        # Convert NumPy array to PIL Image if needed
         if isinstance(image, np.ndarray):
             if len(image.shape) == 2:  # Grayscale (H, W), convert to RGB
                 image = np.stack([image] * 3, axis=-1)
@@ -23,17 +23,17 @@ def extract_text(image):
         else:
             image = Image.open(image).convert("RGB")  # Ensure RGB mode
 
-        # Resize for better accuracy
-        image = image.resize((640, 640))
+        # Maintain aspect ratio while resizing
+        image.thumbnail((640, 640))
 
         # Process with the primary model
-        pixel_values = processor1(images=image, return_tensors="pt").pixel_values
+        pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(torch.float32)
         generated_ids = model1.generate(pixel_values)
         extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
         # If output seems incorrect, use the fallback model
         if len(extracted_text.strip()) < 2:
-            inputs = processor2(images=image, return_tensors="pt").pixel_values
+            inputs = processor2(images=image, return_tensors="pt").pixel_values.to(torch.float32)
             generated_ids = model2.generate(inputs)
             extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
@@ -51,4 +51,4 @@ iface = gr.Interface(
     description="Upload a handwritten document and get the extracted text.",
 )
 
-iface.launch()
+iface.launch()
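
A note on the resize change: PIL's Image.resize((640, 640)) forces the image to exactly 640x640 and distorts the aspect ratio of the handwriting, while Image.thumbnail((640, 640)) shrinks the image in place so it fits within a 640x640 box, preserving proportions and never upscaling. A minimal standalone sketch of the difference (the sample dimensions are illustrative, not from app.py):

from PIL import Image

img = Image.new("RGB", (1280, 720))   # illustrative 16:9 input, not from app.py

stretched = img.resize((640, 640))    # old behavior: exact 640x640, strokes get distorted
print(stretched.size)                 # (640, 640)

img.thumbnail((640, 640))             # new behavior: in-place, aspect ratio preserved
print(img.size)                       # (640, 360) -- fits inside the 640x640 box

Because thumbnail mutates the image and returns None, the diff correctly drops the image = ... assignment that resize required.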
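
The hunks patch the standard TrOCR pattern from the transformers library; the hunk header shows model2 loading what appears to be a microsoft/trocr-base-handwritten checkpoint (the name is truncated in the header). A minimal sketch of the primary-model path as it reads after this commit, assuming both the processor and model come from that checkpoint:

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

image = Image.open("sample.png").convert("RGB")   # hypothetical input path
image.thumbnail((640, 640))                       # the commit's aspect-preserving resize

# The commit adds .to(torch.float32); the processor already returns float32
# tensors by default, so this reads as a defensive cast rather than a fix.
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch.float32)
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

One thing to verify: the new lines reference torch.float32, so app.py needs import torch at the top; the diff does not show whether it is already there.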
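
The final hunk's removed and added iface.launch() lines are textually identical, which usually points to a whitespace or trailing-newline change. For context, the surrounding Gradio setup likely resembles the sketch below; everything except the description string and the launch call is an assumption, since the diff shows only the tail of the gr.Interface(...) call:

import gradio as gr

iface = gr.Interface(
    fn=extract_text,    # assumption: wires up the function patched above
    inputs=gr.Image(),  # assumption: the actual input component is not visible in the diff
    outputs="text",     # assumption: output component not shown
    description="Upload a handwritten document and get the extracted text.",
)

iface.launch()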