Upload folder using huggingface_hub
Browse files- app.py +6 -2
- run_gradio.sh +1 -1
- src/ChatWorld.py +0 -2
- src/DataBase/BaseDB.py +7 -1
app.py
CHANGED
@@ -9,6 +9,10 @@ chatWorld = ChatWorld()
|
|
9 |
role_name_list_global = None
|
10 |
role_name_dict_global = None
|
11 |
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def getContent(input_file):
|
14 |
# ่ฏปๅๆไปถๅ
ๅฎน
|
@@ -28,8 +32,8 @@ def getContent(input_file):
|
|
28 |
role_name_dict_global = role_name_dict
|
29 |
|
30 |
return (
|
31 |
-
gr.Radio(choices=role_name_list, interactive=True
|
32 |
-
gr.Radio(choices=role_name_list, interactive=True
|
33 |
)
|
34 |
|
35 |
|
|
|
9 |
role_name_list_global = None
|
10 |
role_name_dict_global = None
|
11 |
|
12 |
+
Meta = {
|
13 |
+
"uuid":"111"
|
14 |
+
}
|
15 |
+
|
16 |
|
17 |
def getContent(input_file):
|
18 |
# ่ฏปๅๆไปถๅ
ๅฎน
|
|
|
32 |
role_name_dict_global = role_name_dict
|
33 |
|
34 |
return (
|
35 |
+
gr.Radio(choices=role_name_list, interactive=True),
|
36 |
+
gr.Radio(choices=role_name_list, interactive=True),
|
37 |
)
|
38 |
|
39 |
|
run_gradio.sh
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
export CUDA_VISIBLE_DEVICES=
|
2 |
export HF_ENDPOINT="https://hf-mirror.com"
|
3 |
|
4 |
# Start the gradio server
|
|
|
1 |
+
export CUDA_VISIBLE_DEVICES=1
|
2 |
export HF_ENDPOINT="https://hf-mirror.com"
|
3 |
|
4 |
# Start the gradio server
|
src/ChatWorld.py
CHANGED
@@ -17,8 +17,6 @@ class ChatWorld:
|
|
17 |
) -> None:
|
18 |
self.model_name = pretrained_model_name_or_path
|
19 |
|
20 |
-
self.global_batch_size = global_batch_size
|
21 |
-
|
22 |
self.client = GLM_api()
|
23 |
|
24 |
if model_load:
|
|
|
17 |
) -> None:
|
18 |
self.model_name = pretrained_model_name_or_path
|
19 |
|
|
|
|
|
20 |
self.client = GLM_api()
|
21 |
|
22 |
if model_load:
|
src/DataBase/BaseDB.py
CHANGED
@@ -5,6 +5,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
5 |
from transformers import AutoTokenizer
|
6 |
from langchain.text_splitter import TokenTextSplitter
|
7 |
from langchain_core.documents import Document
|
|
|
8 |
|
9 |
|
10 |
class BaseDB(metaclass=ABCMeta):
|
@@ -21,7 +22,12 @@ class BaseDB(metaclass=ABCMeta):
|
|
21 |
if not embedding_name:
|
22 |
embedding_name = "BAAI/bge-small-zh-v1.5"
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
25 |
self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
|
26 |
|
27 |
self.init_db()
|
|
|
5 |
from transformers import AutoTokenizer
|
6 |
from langchain.text_splitter import TokenTextSplitter
|
7 |
from langchain_core.documents import Document
|
8 |
+
from torch.cuda import is_available
|
9 |
|
10 |
|
11 |
class BaseDB(metaclass=ABCMeta):
|
|
|
22 |
if not embedding_name:
|
23 |
embedding_name = "BAAI/bge-small-zh-v1.5"
|
24 |
|
25 |
+
if is_available():
|
26 |
+
model_kwargs = {"device": "cuda"}
|
27 |
+
else:
|
28 |
+
model_kwargs = {"device": "cpu"}
|
29 |
+
|
30 |
+
self.embedding = HuggingFaceEmbeddings(model_name=embedding_name,model_kwargs=model_kwargs)
|
31 |
self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
|
32 |
|
33 |
self.init_db()
|