TongkunGuan committed on
Commit
be042ec
·
verified ·
1 Parent(s): 7ff8aa9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -7,7 +7,8 @@ import torchvision.transforms as T
7
  from transformers import AutoTokenizer
8
  import gradio as gr
9
  from resnet50 import build_model
10
- from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
 
11
  from utils import IMAGENET_MEAN, IMAGENET_STD
12
  from internvl.train.dataset import dynamic_preprocess
13
  from internvl.model.internvl_chat import InternVLChatModel
@@ -42,20 +43,16 @@ def load_model(check_type):
42
  elif 'TokenFD' in check_type:
43
  model_path = CHECKPOINTS[check_type]
44
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
45
- model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
46
- transform = T.Compose([
47
- T.Lambda(lambda img: img.convert('RGB')),
48
- T.Resize((224, 224)),
49
- T.ToTensor(),
50
- T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
51
- ])
52
-
53
  return model.to(device), tokenizer, transform, device
54
 
55
  def process_image(model, tokenizer, transform, device, check_type, image, text):
56
 
57
  src_size = image.size
58
- if 'TokenOCR' in check_type:
59
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
60
  image_size=model.config.force_image_size,
61
  use_thumbnail=model.config.use_thumbnail,
 
7
  from transformers import AutoTokenizer
8
  import gradio as gr
9
  from resnet50 import build_model
10
+ # from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
11
+ from utils import generate_similiarity_map, get_transform, post_process, load_tokenizer, build_transform_R50
12
  from utils import IMAGENET_MEAN, IMAGENET_STD
13
  from internvl.train.dataset import dynamic_preprocess
14
  from internvl.model.internvl_chat import InternVLChatModel
 
43
  elif 'TokenFD' in check_type:
44
  model_path = CHECKPOINTS[check_type]
45
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
46
+ # model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
47
+ model = InternVLChatModel.from_pretrained(checkpoint_vit_english, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 ,load_in_8bit=False, load_in_4bit=False).eval()
48
+ transform = get_transform(is_train=False, image_size=model.config.force_image_size)
49
+
 
 
 
 
50
  return model.to(device), tokenizer, transform, device
51
 
52
  def process_image(model, tokenizer, transform, device, check_type, image, text):
53
 
54
  src_size = image.size
55
+ if 'TokenFD' in check_type:
56
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
57
  image_size=model.config.force_image_size,
58
  use_thumbnail=model.config.use_thumbnail,