DeepDiveDev commited on
Commit
6372edc
·
verified ·
1 Parent(s): 8976d30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -17,7 +17,7 @@ reader = easyocr.Reader(['en', 'hi']) # 'en' for English, 'hi' for Hindi
17
  # Load the GOT-OCR2 model and tokenizer
18
  tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
19
 
20
- # Load the model and move it to the correct device (GPU if available, else CPU)
21
  model = AutoModel.from_pretrained(
22
  'stepfun-ai/GOT-OCR2_0',
23
  trust_remote_code=True,
@@ -25,9 +25,19 @@ model = AutoModel.from_pretrained(
25
  use_safetensors=True,
26
  pad_token_id=tokenizer.eos_token_id
27
  )
28
- model = model.to(device) # Move the model to the correct device
29
  model = model.eval()
30
 
 
 
 
 
 
 
 
 
 
 
31
  # Load MarianMT translation model for Hindi to English translation
32
  translation_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
33
  translation_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
 
17
  # Load the GOT-OCR2 model and tokenizer
18
  tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
19
 
20
+ # Load the model and move it to the correct device
21
  model = AutoModel.from_pretrained(
22
  'stepfun-ai/GOT-OCR2_0',
23
  trust_remote_code=True,
 
25
  use_safetensors=True,
26
  pad_token_id=tokenizer.eos_token_id
27
  )
28
+ model = model.to(device) # Move model to appropriate device
29
  model = model.eval()
30
 
31
+ # Override the chat function to remove hardcoded .cuda()
32
+ def modified_chat(inputs, ocr_type='ocr', ocr_box='', render=False, save_render_file=''):
33
+ input_ids = torch.as_tensor(inputs.input_ids).to(device) # Use .to(device)
34
+ # Replace all .cuda() calls with .to(device)
35
+ # Handle the remaining logic as needed in the function
36
+ # ...
37
+
38
+ # Replace the model's chat method with the modified version
39
+ model.chat = modified_chat
40
+
41
  # Load MarianMT translation model for Hindi to English translation
42
  translation_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
43
  translation_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-hi-en')