Vinay15 committed on
Commit
dba283c
·
verified ·
1 Parent(s): c920662

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -2,37 +2,33 @@ import torch
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
  import gradio as gr
5
- import tempfile
6
 
7
- # Load the OCR model and tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
9
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
10
  low_cpu_mem_usage=True,
11
  pad_token_id=tokenizer.eos_token_id).eval()
12
 
13
- # Check if GPU is available and use it, else use CPU
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = model.to(device)
16
 
17
  # Function to perform OCR on the image
18
  def perform_ocr(image):
19
- # Save the image to a temporary file
20
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
21
- image.save(temp_file.name) # Save the image
22
- temp_image_path = temp_file.name # Get the file path for the saved image
23
 
24
- # Perform OCR using the model
25
- result = model.chat(tokenizer, temp_image_path, ocr_type='ocr')
 
 
 
 
26
  return result
27
 
28
- # Create the Gradio interface using the new syntax
29
- interface = gr.Interface(
30
- fn=perform_ocr,
31
- inputs=gr.Image(type="pil"), # Updated to gr.Image
32
- outputs=gr.Textbox(), # Updated to gr.Textbox
33
- title="OCR Web App",
34
- description="Upload an image to extract text using the GOT-OCR2.0 model."
35
- )
36
 
37
- # Launch the app
38
- interface.launch()
 
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
+ # Load the OCR model and tokenizer with low memory usage in mind
7
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
8
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
9
  low_cpu_mem_usage=True,
10
  pad_token_id=tokenizer.eos_token_id).eval()
11
 
12
+ # Ensure we are using CPU
13
+ device = torch.device('cpu')
14
  model = model.to(device)
15
 
16
  # Function to perform OCR on the image
17
  def perform_ocr(image):
18
+ # Open the image using PIL
19
+ pil_image = Image.open(image)
 
 
20
 
21
+ # Use torch.no_grad() to avoid unnecessary memory usage
22
+ with torch.no_grad():
23
+ # Perform OCR using the model (image passed as PIL image)
24
+ result = model.chat(tokenizer, pil_image, ocr_type='ocr')
25
+
26
+ # Return the extracted text
27
  return result
28
 
29
+ # Create the Gradio interface for file upload and OCR
30
+ iface = gr.Interface(fn=perform_ocr, inputs="file", outputs="text",
31
+ title="OCR Application", description="Upload an image to extract text.")
 
 
 
 
 
32
 
33
+ # Launch the Gradio app
34
+ iface.launch()