Vinay15 committed on
Commit
de0d96a
·
verified ·
1 Parent(s): c35e395

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -38
app.py CHANGED
@@ -1,51 +1,26 @@
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
- import torch
5
 
6
- # Load the tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
8
-
9
- # Try loading the model with error handling
10
- try:
11
- model = AutoModel.from_pretrained(
12
- 'ucaslcl/GOT-OCR2_0',
13
- trust_remote_code=True,
14
- low_cpu_mem_usage=True,
15
- device_map='auto', # Use 'auto' to decide whether to use CPU or GPU
16
- use_safetensors=True,
17
- pad_token_id=tokenizer.eos_token_id
18
- )
19
-
20
- # Check if CUDA (GPU) is available, else fall back to CPU
21
- if torch.cuda.is_available():
22
- model = model.eval().cuda()
23
- print("Model loaded on GPU.")
24
- else:
25
- model = model.eval().cpu()
26
- print("CUDA not available, model loaded on CPU.")
27
-
28
- except Exception as e:
29
- print(f"Error loading model: {e}")
30
 
31
  # Define the OCR function
32
  def perform_ocr(image):
33
- try:
34
- # Convert PIL image to RGB format (if necessary)
35
- if image.mode != "RGB":
36
- image = image.convert("RGB")
37
 
38
- # Save the image to a temporary path
39
- image_file_path = 'temp_image.jpg'
40
- image.save(image_file_path)
41
 
42
- # Perform OCR using the model
43
- res = model.chat(tokenizer, image_file_path, ocr_type='ocr')
44
 
45
- return res
46
-
47
- except Exception as e:
48
- return str(e)
49
 
50
  # Define the Gradio interface
51
  interface = gr.Interface(
@@ -57,4 +32,4 @@ interface = gr.Interface(
57
  )
58
 
59
  # Launch the Gradio app
60
- interface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
 
4
 
5
# Load the GOT-OCR2 tokenizer and model from the Hugging Face Hub.
# The model is left on the CPU (no .cuda() call) and put in eval mode
# since this Space performs inference only.
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval()  # inference mode; stays on CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Define the OCR function
def perform_ocr(image):
    """Run GOT-OCR2 on an uploaded image and return the recognized text.

    Args:
        image: PIL.Image.Image supplied by the Gradio interface.

    Returns:
        str: the OCR result produced by ``model.chat``.
    """
    import os
    import tempfile

    # The model's path-based API expects an RGB image file on disk.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Use a unique temporary file instead of a fixed 'temp_image.jpg' in the
    # CWD: concurrent Gradio requests would otherwise clobber each other's
    # uploads, and the fixed file was never cleaned up.
    fd, image_file_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)  # we only need the path; PIL reopens the file itself
    try:
        image.save(image_file_path)
        # Perform OCR using the model.
        return model.chat(tokenizer, image_file_path, ocr_type='ocr')
    finally:
        os.remove(image_file_path)  # always delete the temp file
 
 
 
24
 
25
  # Define the Gradio interface
26
  interface = gr.Interface(
 
32
  )
33
 
34
  # Launch the Gradio app
35
+ interface.launch()