zfzhang-thu commited on
Commit
109f057
1 Parent(s): ab17c1a

add code for achieve vicuna-7b on huggingface

Browse files
Files changed (1) hide show
  1. leo/model.py +7 -5
leo/model.py CHANGED
@@ -1,10 +1,7 @@
1
  import contextlib
2
- import math
3
-
4
  import clip
5
  import torch
6
  import torch.nn as nn
7
- import torch.nn.functional as F
8
  from einops import rearrange
9
  from peft import LoraConfig, get_peft_model
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
@@ -42,10 +39,12 @@ class SequentialGrounder(torch.nn.Module):
42
  def __init__(self,predict_mode=False):
43
  super().__init__()
44
  cfg = {
 
45
  "model": {
46
  "llm": {
47
  "name": "Vicuna7B",
48
  "cfg_path": "/scratch/generalvision/vicuna-7b",
 
49
  "truncation_side": "right",
50
  "max_context_len": 256,
51
  "max_out_len": 256,
@@ -67,8 +66,11 @@ class SequentialGrounder(torch.nn.Module):
67
  self.predict_mode = predict_mode
68
 
69
  # LLM
70
- llm_name = 'Vicuna7B'
71
- llm_cfg_path = '/scratch/generalvision/vicuna-7b'
 
 
 
72
  llm_truncation_side = 'right'
73
  if 'vicuna' in llm_name.lower():
74
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
 
1
  import contextlib
 
 
2
  import clip
3
  import torch
4
  import torch.nn as nn
 
5
  from einops import rearrange
6
  from peft import LoraConfig, get_peft_model
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
 
39
  def __init__(self,predict_mode=False):
40
  super().__init__()
41
  cfg = {
42
+ "launch_mode": "hf",
43
  "model": {
44
  "llm": {
45
  "name": "Vicuna7B",
46
  "cfg_path": "/scratch/generalvision/vicuna-7b",
47
+ "hf_cfg_path": "huangjy-pku/vicuna-7b",
48
  "truncation_side": "right",
49
  "max_context_len": 256,
50
  "max_out_len": 256,
 
66
  self.predict_mode = predict_mode
67
 
68
  # LLM
69
+ llm_name = cfg['model']['llm']['name']
70
+ if cfg['launch_mode'] == 'hf':
71
+ llm_cfg_path = cfg['model']['llm']['hf_cfg_path']
72
+ else:
73
+ llm_cfg_path = cfg['model']['llm']['cfg_path']
74
  llm_truncation_side = 'right'
75
  if 'vicuna' in llm_name.lower():
76
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)