Jangai committed on
Commit
11b9b6a
·
verified ·
1 Parent(s): 2f5a61a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -45
app.py CHANGED
@@ -1,61 +1,46 @@
1
  import gradio as gr
2
  import numpy as np
 
 
3
  import matplotlib.pyplot as plt
4
  import tempfile
5
- import os
6
- import logging
7
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
8
- from PIL import Image
9
-
10
- logging.basicConfig(level=logging.DEBUG)
11
 
12
- # Initialize the TrOCR model and processor
13
- processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
14
- model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
15
 
16
  def display_sketch(sketch):
17
- logging.debug(f"Received sketch data: {sketch}")
18
-
19
  if isinstance(sketch, dict) and 'composite' in sketch:
20
- image_data = np.array(sketch['composite'], dtype=np.uint8)
21
- logging.debug(f"Image data type: {type(image_data)}")
22
- logging.debug(f"Image data shape: {image_data.shape}")
23
-
24
- # Ensure the image is in the correct format
25
- image = Image.fromarray(image_data, 'RGBA').convert('RGB')
26
-
27
- temp_file_path = os.path.join(os.getcwd(), "output.png")
28
- image.save(temp_file_path)
29
- logging.debug(f"Image saved to: {temp_file_path}")
30
-
31
- return temp_file_path
32
- else:
33
- error_message = f"Unexpected sketch data format: {type(sketch)}"
34
- logging.error(error_message)
35
- return error_message
36
 
37
  def recognize_text(image_path):
38
- try:
39
- # Load the image
40
- image = Image.open(image_path).convert("RGB")
41
- # Prepare the image for the model
42
- pixel_values = processor(image, return_tensors="pt").pixel_values
43
- # Generate the text
44
- generated_ids = model.generate(pixel_values)
45
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
46
- logging.debug(f"Recognized text: {generated_text}")
47
- return generated_text
48
- except Exception as e:
49
- logging.error(f"Error in recognizing text: {e}")
50
- return "Error in recognizing text"
51
-
52
  with gr.Blocks() as demo:
53
  sketchpad = gr.Sketchpad(label="Draw Something")
54
  output_image = gr.Image(label="Your Sketch")
55
  recognized_text = gr.Textbox(label="Recognized Text")
56
- submit_btn = gr.Button("Submit")
57
-
58
- submit_btn.click(fn=display_sketch, inputs=sketchpad, outputs=output_image)
59
- submit_btn.click(fn=recognize_text, inputs=output_image, outputs=recognized_text)
60
 
 
61
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ from PIL import Image
4
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
  import matplotlib.pyplot as plt
6
  import tempfile
 
 
 
 
 
 
7
 
8
# Initialize the TrOCR model and processor once at import time so every
# request reuses the same loaded weights (loading is slow and memory-heavy).
# NOTE(review): "microsoft/trocr-large-handwritten" is a vision
# encoder-decoder checkpoint for handwritten text — downloaded from the
# Hugging Face hub on first run; requires network access at startup.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
11
 
12
def display_sketch(sketch):
    """Save the sketchpad's composite drawing to a PNG file.

    Parameters
    ----------
    sketch : dict
        Gradio Sketchpad value; the drawn image is expected under the
        ``'composite'`` key as a numpy array (H, W, C).

    Returns
    -------
    str or None
        Path of the saved PNG, or ``None`` when the input does not have
        the expected structure.
    """
    # Guard clauses: anything that is not {'composite': ndarray} is rejected.
    if not (isinstance(sketch, dict) and 'composite' in sketch):
        return None
    image_data = sketch['composite']
    if not isinstance(image_data, np.ndarray):
        return None

    # Let PIL infer the color mode from the array shape instead of
    # hard-coding 'RGBA' — the sketchpad may deliver 3-channel (RGB)
    # data, which would make an explicit 'RGBA' mode raise.
    img = Image.fromarray(image_data.astype(np.uint8))

    # Use a unique temp file rather than the hard-coded absolute path
    # "/home/user/app/output.png", which only exists in one container
    # layout and is clobbered by concurrent users.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_file = tmp.name
    img.save(temp_file)
    return temp_file
 
 
 
 
 
 
 
 
 
21
 
22
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on the image at *image_path*.

    Parameters
    ----------
    image_path : str or None
        Path to the sketch PNG; ``None`` when the upstream step failed
        to produce a file.

    Returns
    -------
    str
        The recognized text, or a human-readable error message when
        recognition fails (shown directly in the UI textbox).
    """
    # display_sketch returns None on malformed input — without this
    # guard Image.open(None) would raise and crash the event handler.
    if image_path is None:
        return "No sketch to recognize"
    try:
        # Context manager closes the file handle; the model expects RGB.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return f"Error in recognizing text: {e}"
34
+
35
# Define the Gradio interface: sketchpad -> saved image -> recognized text.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" makes the component hand recognize_text the saved
    # file's path; the default ("numpy") would pass an ndarray, which
    # Image.open cannot consume.
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")

    # Chain the steps: save the sketch first, then OCR the saved image.
    sketchpad.submit(display_sketch, inputs=sketchpad, outputs=output_image).then(
        recognize_text, inputs=output_image, outputs=recognized_text
    )

# Launch the demo
demo.launch()