DeepDiveDev committed
Commit df3681e · verified · 1 Parent(s): 4a2ef6d

Update app.py

Files changed (1)
  1. app.py +10 -17
app.py CHANGED
@@ -1,3 +1,4 @@
+
 import streamlit as st
 from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
 from PIL import Image
@@ -8,8 +9,8 @@ import re
 import torch
 
 # Check if GPU is available, else default to CPU
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-st.write(f"Using device: {device.upper()}")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+st.write(f"Using device: {device}")
 
 # Load EasyOCR reader with English and Hindi language support
 reader = easyocr.Reader(['en', 'hi']) # 'en' for English, 'hi' for Hindi
@@ -17,17 +18,14 @@ reader = easyocr.Reader(['en', 'hi']) # 'en' for English, 'hi' for Hindi
 # Load the GOT-OCR2 model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
 
-# Load the model with low memory usage on CPU or auto-map for GPU if available
+# Load the model and move it to the correct device (GPU if available, else CPU)
 model = AutoModel.from_pretrained(
     'stepfun-ai/GOT-OCR2_0',
     trust_remote_code=True,
     low_cpu_mem_usage=True,
-    device_map='auto' if device == 'cuda' else None, # Use GPU if available, else None
     use_safetensors=True,
     pad_token_id=tokenizer.eos_token_id
 )
-
-# Move model to appropriate device (GPU or CPU)
 model = model.to(device)
 model = model.eval()
 
@@ -61,14 +59,12 @@ if image_file is not None:
     # Button to run OCR
     if st.button("Run OCR"):
         # Ensure model runs on CPU if GPU isn't available
-        if device == 'cuda':
-            res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
-        else:
-            with torch.no_grad(): # Disable gradient calculations to save memory on CPU
-                res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
+        with torch.no_grad(): # Disable gradient calculations to save memory on CPU
+            # Replace .cuda() with device handling based on CPU/GPU availability
+            res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr', device=device)
 
         # Perform formatted text OCR
-        res_format = model.chat(tokenizer, temp_file_path, ocr_type='format')
+        res_format = model.chat(tokenizer, temp_file_path, ocr_type='format', device=device)
 
         # Use EasyOCR for both English and Hindi text recognition
         result_easyocr = reader.readtext(temp_file_path, detail=0)
@@ -98,12 +94,12 @@ if image_file is not None:
         st.write(" ".join(translated_text))
 
         # Additional OCR types using GOT-OCR2
-        res_fine_grained = model.chat(tokenizer, temp_file_path, ocr_type='ocr', ocr_box='')
+        res_fine_grained = model.chat(tokenizer, temp_file_path, ocr_type='ocr', ocr_box='', device=device)
         st.subheader("Fine-Grained OCR Results:")
         st.write(res_fine_grained)
 
         # Render formatted OCR to HTML
-        res_render = model.chat(tokenizer, temp_file_path, ocr_type='format', render=True, save_render_file='./demo.html')
+        res_render = model.chat(tokenizer, temp_file_path, ocr_type='format', render=True, save_render_file='./demo.html', device=device)
         st.subheader("Rendered OCR Results (HTML):")
         st.write(res_render)
 
@@ -118,6 +114,3 @@ if image_file is not None:
 
         # Clean up the temporary file after use
         os.remove(temp_file_path)
-
-        # Note: No need for if __name__ == "__main__": st.run()
-
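For reference, a minimal standalone sketch of the device-handling pattern this commit settles on: pick the device once with torch.device, load the model without device_map, move it with a single .to(device), and run every inference call inside torch.no_grad(). The chat() method and its device= keyword are GOT-OCR2's trust_remote_code API as invoked in the diff (not a standard transformers method), and 'sample.png' is a placeholder image path used only for illustration.

import torch
from transformers import AutoModel, AutoTokenizer

# Select the device once; everything below follows from this choice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained(
    'stepfun-ai/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.to(device).eval()  # one explicit move replaces the old device_map branching

with torch.no_grad():  # inference only; skips autograd bookkeeping on CPU and GPU alike
    # GOT-OCR2's custom chat() as used in the updated app.py; 'sample.png' is a placeholder.
    plain_text = model.chat(tokenizer, 'sample.png', ocr_type='ocr', device=device)
    formatted = model.chat(tokenizer, 'sample.png', ocr_type='format', device=device)

print(plain_text)
print(formatted)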