Jangai committed on
Commit
cc24739
·
verified ·
1 Parent(s): 0f1409a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -1,49 +1,48 @@
1
  import gradio as gr
2
  import numpy as np
 
 
3
  from PIL import Image
 
4
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
 
6
- # Initialize the model and processor
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
9
 
10
  def display_sketch(sketch):
11
- if isinstance(sketch, dict) and 'composite' in sketch:
12
- image_data = sketch['composite']
13
- if isinstance(image_data, np.ndarray):
14
- img = Image.fromarray(image_data.astype('uint8'), 'RGBA')
15
- temp_file = "/home/user/app/output.png"
16
- img.save(temp_file)
17
- return temp_file
18
- return None
 
19
 
20
  def recognize_text(image_path):
21
- # Open the image
22
- image = Image.open(image_path)
23
- # Convert image to RGB
24
- image = image.convert("RGB")
 
 
25
 
26
- # Process the image
27
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
 
28
  generated_ids = model.generate(pixel_values)
29
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
30
-
31
  return generated_text
32
 
33
- # Define the Gradio interface
34
  with gr.Blocks() as demo:
35
- sketchpad = gr.Sketchpad(label="Draw Something")
36
- output_image = gr.Image(label="Your Sketch")
37
  recognized_text = gr.Textbox(label="Recognized Text")
38
 
39
- def process_and_recognize(sketch):
40
- image_path = display_sketch(sketch)
41
- if image_path:
42
- text = recognize_text(image_path)
43
- return image_path, text
44
- return None, ""
45
-
46
- sketchpad.change(process_and_recognize, inputs=sketchpad, outputs=[output_image, recognized_text])
47
 
48
- # Launch the demo
49
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import tempfile
5
  from PIL import Image
6
+ import torch
7
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
8
 
9
# Load the TrOCR handwriting-recognition model and its processor once at
# module import, so every recognition request reuses the same instances
# instead of re-downloading/re-initializing the weights.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
12
 
13
def display_sketch(sketch):
    """Render the sketchpad's composite drawing and save it as a PNG.

    Parameters
    ----------
    sketch : dict
        Gradio Sketchpad value; the rendered drawing is expected under
        the ``'composite'`` key as a numpy array.

    Returns
    -------
    str or None
        Path of the saved PNG file, or ``None`` when no drawing is
        available (empty/cleared sketchpad).
    """
    # Guard against a cleared or malformed sketch value, as the previous
    # revision of this function did.
    if not isinstance(sketch, dict) or sketch.get('composite') is None:
        return None
    image_data = sketch['composite']

    # Write to a real temporary file. The previous hard-coded
    # "/mnt/data/output.png" is a sandbox path that does not exist in a
    # Gradio Space container; `tempfile` was imported but never used.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_file_path = tmp.name

    plt.imshow(image_data)
    plt.axis('off')
    # bbox_inches='tight' + pad_inches=0 crops the figure to the drawing.
    plt.savefig(temp_file_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    return temp_file_path
23
 
24
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on the image at *image_path*.

    Parameters
    ----------
    image_path : str
        Path to an image file containing handwriting.

    Returns
    -------
    str
        The recognized text (first decoded sequence).
    """
    # Open image and convert to grayscale for binarization.
    image = Image.open(image_path).convert("L")
    # Resize to a fixed working size.
    image = image.resize((256, 256))
    # Binarize: pixels above 128 become 255, everything else 0.
    # (Equivalent to the old `p > 128 and 255`, but explicit.)
    image = image.point(lambda p: 255 if p > 128 else 0)
    # TrOCR's processor normalizes with per-channel (3-channel) mean/std
    # and expects an RGB image, so convert the binarized "L" image back
    # to RGB before preprocessing.
    image = image.convert("RGB")

    # Preprocess the image into model-ready pixel values.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
38
 
 
39
# Build the Gradio UI: a sketchpad feeds the saved-sketch preview, which
# in turn feeds the OCR text box.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something", brush_radius=10)
    # type="filepath" so downstream recognize_text(image_path) receives a
    # path string; the default numpy payload would break Image.open().
    sketchpad_output = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")

    # A sketchpad has no "submit" gesture/event; trigger on each change
    # (as the previous working revision did), then chain recognition on
    # the saved image.
    sketchpad.change(display_sketch, inputs=sketchpad, outputs=sketchpad_output).then(
        recognize_text, inputs=sketchpad_output, outputs=recognized_text
    )

demo.launch()