DeepDiveDev commited on
Commit
2b1ac83
·
verified ·
1 Parent(s): 6372edc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -29,11 +29,12 @@ model = model.to(device) # Move model to appropriate device
29
  model = model.eval()
30
 
31
  # Override the chat function to remove hardcoded .cuda()
32
- def modified_chat(inputs, ocr_type='ocr', ocr_box='', render=False, save_render_file=''):
33
  input_ids = torch.as_tensor(inputs.input_ids).to(device) # Use .to(device)
34
- # Replace all .cuda() calls with .to(device)
35
- # Handle the remaining logic as needed in the function
36
- # ...
 
37
 
38
  # Replace the model's chat method with the modified version
39
  model.chat = modified_chat
@@ -68,7 +69,7 @@ if image_file is not None:
68
  if st.button("Run OCR"):
69
  # Use GOT-OCR2 model for plain text OCR (structured documents)
70
  with torch.no_grad():
71
- res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr')
72
 
73
  # Perform formatted text OCR
74
  with torch.no_grad():
 
29
  model = model.eval()
30
 
31
  # Override the chat function to remove hardcoded .cuda()
32
+ def modified_chat(inputs, *args, ocr_type='ocr', **kwargs):
33
  input_ids = torch.as_tensor(inputs.input_ids).to(device) # Use .to(device)
34
+ # Additional processing logic here
35
+ # Example: replace with actual model inference code if necessary
36
+ # res = model(input_ids)
37
+ return f"Processed input: {input_ids}, OCR Type: {ocr_type}"
38
 
39
  # Replace the model's chat method with the modified version
40
  model.chat = modified_chat
 
69
  if st.button("Run OCR"):
70
  # Use GOT-OCR2 model for plain text OCR (structured documents)
71
  with torch.no_grad():
72
+ res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr') # Ensure the correct parameters are passed
73
 
74
  # Perform formatted text OCR
75
  with torch.no_grad():