zfzhang-thu committed
Commit • 109f057
1 Parent(s): ab17c1a
add code to load vicuna-7b from huggingface
leo/model.py +7 -5
leo/model.py CHANGED
@@ -1,10 +1,7 @@
 import contextlib
-import math
-
 import clip
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange
 from peft import LoraConfig, get_peft_model
 from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
@@ -42,10 +39,12 @@ class SequentialGrounder(torch.nn.Module):
     def __init__(self,predict_mode=False):
         super().__init__()
         cfg = {
+            "launch_mode": "hf",
             "model": {
                 "llm": {
                     "name": "Vicuna7B",
                     "cfg_path": "/scratch/generalvision/vicuna-7b",
+                    "hf_cfg_path": "huangjy-pku/vicuna-7b",
                     "truncation_side": "right",
                     "max_context_len": 256,
                     "max_out_len": 256,
@@ -67,8 +66,11 @@ class SequentialGrounder(torch.nn.Module):
         self.predict_mode = predict_mode
 
         # LLM
-        llm_name = '
-
+        llm_name = cfg['model']['llm']['name']
+        if cfg['launch_mode'] == 'hf':
+            llm_cfg_path = cfg['model']['llm']['hf_cfg_path']
+        else:
+            llm_cfg_path = cfg['model']['llm']['cfg_path']
         llm_truncation_side = 'right'
         if 'vicuna' in llm_name.lower():
             self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
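For reference, the loading logic this commit introduces, pulled out into a minimal self-contained sketch. The cfg values mirror the ones in the diff; the standalone script around them is illustrative, not code from the repository. LlamaTokenizer.from_pretrained resolves a Hub repo id such as huangjy-pku/vicuna-7b and a local directory with the same call, so switching launch_mode only has to swap the path string.

# Minimal sketch of the path selection added in this commit.
# cfg values are taken from the diff; the standalone wrapper is illustrative.
from transformers import LlamaTokenizer

cfg = {
    "launch_mode": "hf",  # "hf" -> load from the Hugging Face Hub, else local path
    "model": {
        "llm": {
            "name": "Vicuna7B",
            "cfg_path": "/scratch/generalvision/vicuna-7b",  # local checkout
            "hf_cfg_path": "huangjy-pku/vicuna-7b",          # Hub repo id
            "truncation_side": "right",
        }
    },
}

llm_cfg = cfg["model"]["llm"]
if cfg["launch_mode"] == "hf":
    llm_cfg_path = llm_cfg["hf_cfg_path"]
else:
    llm_cfg_path = llm_cfg["cfg_path"]

# from_pretrained accepts either a repo id or a local directory, so nothing
# else has to change between the two launch modes.
tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_cfg["truncation_side"])

Presumably the local /scratch path is not visible inside the Space, which is why the hf launch mode and the Hub copy of vicuna-7b were added.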