DeepDiveDev committed on
Commit
45b88db
·
verified ·
1 Parent(s): df3681e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import streamlit as st
3
  from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
4
  from PIL import Image
@@ -26,7 +25,7 @@ model = AutoModel.from_pretrained(
26
  use_safetensors=True,
27
  pad_token_id=tokenizer.eos_token_id
28
  )
29
- model = model.to(device)
30
  model = model.eval()
31
 
32
  # Load MarianMT translation model for Hindi to English translation
@@ -58,13 +57,13 @@ if image_file is not None:
58
 
59
  # Button to run OCR
60
  if st.button("Run OCR"):
61
- # Ensure model runs on CPU if GPU isn't available
62
- with torch.no_grad(): # Disable gradient calculations to save memory on CPU
63
- # Replace .cuda() with device handling based on CPU/GPU availability
64
- res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr', device=device)
65
 
66
  # Perform formatted text OCR
67
- res_format = model.chat(tokenizer, temp_file_path, ocr_type='format', device=device)
 
68
 
69
  # Use EasyOCR for both English and Hindi text recognition
70
  result_easyocr = reader.readtext(temp_file_path, detail=0)
@@ -84,9 +83,9 @@ if image_file is not None:
84
  st.subheader("Translated Hindi Text to English:")
85
  translated_text = []
86
  for sentence in result_easyocr:
87
- # Detect if the text is in Hindi (you can customize this based on text properties)
88
  if sentence: # Assuming non-empty text is translated
89
  tokenized_text = translation_tokenizer([sentence], return_tensors="pt", truncation=True)
 
90
  translation = translation_model.generate(**tokenized_text)
91
  translated_sentence = translation_tokenizer.decode(translation[0], skip_special_tokens=True)
92
  translated_text.append(translated_sentence)
@@ -94,12 +93,14 @@ if image_file is not None:
94
  st.write(" ".join(translated_text))
95
 
96
  # Additional OCR types using GOT-OCR2
97
- res_fine_grained = model.chat(tokenizer, temp_file_path, ocr_type='ocr', ocr_box='', device=device)
 
98
  st.subheader("Fine-Grained OCR Results:")
99
  st.write(res_fine_grained)
100
 
101
  # Render formatted OCR to HTML
102
- res_render = model.chat(tokenizer, temp_file_path, ocr_type='format', render=True, save_render_file='./demo.html', device=device)
 
103
  st.subheader("Rendered OCR Results (HTML):")
104
  st.write(res_render)
105
 
@@ -114,3 +115,5 @@ if image_file is not None:
114
 
115
  # Clean up the temporary file after use
116
  os.remove(temp_file_path)
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
3
  from PIL import Image
 
25
  use_safetensors=True,
26
  pad_token_id=tokenizer.eos_token_id
27
  )
28
+ model = model.to(device) # Move the model to the correct device
29
  model = model.eval()
30
 
31
  # Load MarianMT translation model for Hindi to English translation
 
57
 
58
  # Button to run OCR
59
  if st.button("Run OCR"):
60
+ # Use GOT-OCR2 model for plain text OCR (structured documents)
61
+ with torch.no_grad():
62
+ res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
 
63
 
64
  # Perform formatted text OCR
65
+ with torch.no_grad():
66
+ res_format = model.chat(tokenizer, temp_file_path, ocr_type='format')
67
 
68
  # Use EasyOCR for both English and Hindi text recognition
69
  result_easyocr = reader.readtext(temp_file_path, detail=0)
 
83
  st.subheader("Translated Hindi Text to English:")
84
  translated_text = []
85
  for sentence in result_easyocr:
 
86
  if sentence: # Assuming non-empty text is translated
87
  tokenized_text = translation_tokenizer([sentence], return_tensors="pt", truncation=True)
88
+ tokenized_text = {key: val.to(device) for key, val in tokenized_text.items()} # Move tensors to device
89
  translation = translation_model.generate(**tokenized_text)
90
  translated_sentence = translation_tokenizer.decode(translation[0], skip_special_tokens=True)
91
  translated_text.append(translated_sentence)
 
93
  st.write(" ".join(translated_text))
94
 
95
  # Additional OCR types using GOT-OCR2
96
+ with torch.no_grad():
97
+ res_fine_grained = model.chat(tokenizer, temp_file_path, ocr_type='ocr', ocr_box='')
98
  st.subheader("Fine-Grained OCR Results:")
99
  st.write(res_fine_grained)
100
 
101
  # Render formatted OCR to HTML
102
+ with torch.no_grad():
103
+ res_render = model.chat(tokenizer, temp_file_path, ocr_type='format', render=True, save_render_file='./demo.html')
104
  st.subheader("Rendered OCR Results (HTML):")
105
  st.write(res_render)
106
 
 
115
 
116
  # Clean up the temporary file after use
117
  os.remove(temp_file_path)
118
+
119
+ # Note: No need for if __name__ == "__main__": st.run()