vittoriopippi commited on
Commit
160658a
·
1 Parent(s): af10767

Change imports

Browse files
Files changed (2) hide show
  1. modeling_vatrpp.py +5 -0
  2. models/model.py +1 -1
modeling_vatrpp.py CHANGED
@@ -10,6 +10,7 @@ from .data.dataset import FolderDataset
10
  from .models.model import VATr
11
  from .models.util.vision import detect_text_bounds
12
  from torchvision.transforms.functional import to_pil_image
 
13
 
14
 
15
  def get_long_tail_chars():
@@ -26,6 +27,10 @@ class VATrPP(PreTrainedModel):
26
 
27
  def __init__(self, config: VATrPPConfig) -> None:
28
  super().__init__(config)
 
 
 
 
29
  self.model = VATr(config)
30
  self.model.eval()
31
 
 
10
  from .models.model import VATr
11
  from .models.util.vision import detect_text_bounds
12
  from torchvision.transforms.functional import to_pil_image
13
+ from huggingface_hub import hf_hub_download
14
 
15
 
16
  def get_long_tail_chars():
 
27
 
28
  def __init__(self, config: VATrPPConfig) -> None:
29
  super().__init__(config)
30
+
31
+ config.english_words_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename=config.english_words_path)
32
+ config.mytext_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename='mytext.txt')
33
+
34
  self.model = VATr(config)
35
  self.model.eval()
36
 
models/model.py CHANGED
@@ -260,7 +260,7 @@ class VATr(nn.Module):
260
 
261
  self.epoch = 0
262
 
263
- with open('mytext.txt', 'r', encoding='utf-8') as f:
264
  self.text = f.read()
265
  self.text = self.text.replace('\n', ' ')
266
  self.text = self.text.replace('\n', ' ')
 
260
 
261
  self.epoch = 0
262
 
263
+ with open(args.mytext_path, 'r', encoding='utf-8') as f:
264
  self.text = f.read()
265
  self.text = self.text.replace('\n', ' ')
266
  self.text = self.text.replace('\n', ' ')