Arch10 committed
Commit aca06b8 · verified · 1 Parent(s): a7b5cf9

Update app.py

Files changed (1)
  1. app.py +9 -11
app.py CHANGED
@@ -1,16 +1,16 @@
-# Streamlit app for extracting text from an image using the General OCR Theory (GOT) 2.0 model
 import streamlit as st
 from transformers import AutoTokenizer, AutoModel
 import torch
 from PIL import Image
-import requests
 
 # Load the pre-trained GOT OCR 2.0 model and tokenizer
 @st.cache_resource(show_spinner=True)
 def load_model():
     tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
-    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
-    return tokenizer, model.eval().cuda()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Check for GPU, fallback to CPU
+    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True)
+    model = model.eval().to(device)  # Move the model to the appropriate device
+    return tokenizer, model, device
 
 # Streamlit interface
 st.title("OCR Application using General OCR Theory (GOT) 2.0")
@@ -24,17 +24,15 @@ if uploaded_file is not None:
     st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
 
     # Load model
-    tokenizer, model = load_model()
+    tokenizer, model, device = load_model()
 
-    # Load the image into the model
-    with open(uploaded_file.name, 'wb') as f:
-        f.write(uploaded_file.getbuffer())
+    # Load the image
+    image = Image.open(uploaded_file)
+    image.save("temp_image.png")  # Save the uploaded image to a temporary file
 
-    image_file = uploaded_file.name
-
     # Perform OCR
     with st.spinner("Extracting text..."):
-        res = model.chat(tokenizer, image_file, ocr_type='ocr')
+        res = model.chat(tokenizer, "temp_image.png", ocr_type='ocr')
 
     # Display the result
     st.write("Extracted Text:")
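
For reference, a minimal sketch (not part of the commit) of the updated load path run outside Streamlit. It assumes the same ucaslcl/GOT-OCR2_0 repo and the model.chat(tokenizer, image_path, ocr_type='ocr') call that app.py uses; "sample.png" is a placeholder path for any local image.

# Minimal sketch: exercise the updated load path without Streamlit.
# Assumes the same model repo and chat() call as app.py; "sample.png" is a placeholder image path.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # same GPU/CPU fallback as load_model()
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                  low_cpu_mem_usage=True, use_safetensors=True)
model = model.eval().to(device)

text = model.chat(tokenizer, "sample.png", ocr_type='ocr')  # ocr_type='ocr' requests plain-text extraction
print(text)

The app itself is still launched with streamlit run app.py; whether inference actually succeeds without a GPU depends on the model's trust_remote_code implementation, which the commit does not change.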