Update app.py
Browse files
app.py
CHANGED
@@ -5,11 +5,10 @@ from safetensors.torch import load_file
|
|
5 |
from transformers import AutoTokenizer, AutoModel
|
6 |
import gc
|
7 |
|
8 |
-
#
|
9 |
gc.collect()
|
10 |
torch.cuda.empty_cache()
|
11 |
|
12 |
-
# 1. 定义MultiTaskRoberta模型架构
|
13 |
class MultiTaskRoberta(nn.Module):
|
14 |
def __init__(self, base_model):
|
15 |
super().__init__()
|
@@ -24,30 +23,27 @@ class MultiTaskRoberta(nn.Module):
|
|
24 |
regs = self.regressor(pooled)
|
25 |
return {"logits": logits, "regression_outputs": regs}
|
26 |
|
27 |
-
# 2. 准备模型和tokenizer
|
28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
-
print(f"
|
30 |
|
31 |
-
#
|
32 |
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
|
33 |
|
34 |
-
#
|
35 |
base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
|
36 |
model = MultiTaskRoberta(base_model)
|
37 |
|
38 |
-
#
|
39 |
-
model_path = "
|
40 |
state_dict = load_file(model_path, device="cpu")
|
41 |
model.load_state_dict(state_dict)
|
42 |
model.to(device)
|
43 |
model.eval()
|
44 |
|
45 |
-
#
|
46 |
# if device.type == 'cuda':
|
47 |
# model.half()
|
48 |
-
# print("使用半精度模型")
|
49 |
|
50 |
-
# 3. 优化后的推理函数
|
51 |
def predict(text: str):
|
52 |
try:
|
53 |
inputs = tokenizer(
|
@@ -58,7 +54,6 @@ def predict(text: str):
|
|
58 |
max_length=128
|
59 |
)
|
60 |
|
61 |
-
# 将输入移到设备
|
62 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
63 |
|
64 |
with torch.no_grad():
|
@@ -71,7 +66,6 @@ def predict(text: str):
|
|
71 |
pred_class = torch.argmax(out["logits"], dim=-1).item()
|
72 |
sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
|
73 |
|
74 |
-
# 将结果移回CPU处理
|
75 |
reg_results = out["regression_outputs"][0].cpu().numpy()
|
76 |
rating, delight, anger, sorrow, happiness = reg_results
|
77 |
|
@@ -86,7 +80,7 @@ def predict(text: str):
|
|
86 |
except Exception as e:
|
87 |
return {"错误": f"处理失败: {str(e)}"}
|
88 |
|
89 |
-
#
|
90 |
iface = gr.Interface(
|
91 |
fn=predict,
|
92 |
inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
|
@@ -108,6 +102,5 @@ iface = gr.Interface(
|
|
108 |
]
|
109 |
)
|
110 |
|
111 |
-
# 5. 启动应用 - 使用兼容的启动方式
|
112 |
if __name__ == "__main__":
|
113 |
iface.launch(share=True, show_error=True)
|
|
|
5 |
from transformers import AutoTokenizer, AutoModel
|
6 |
import gc
|
7 |
|
8 |
+
# Run garbage collection and release cached CUDA memory before loading the model
|
9 |
gc.collect()
|
10 |
torch.cuda.empty_cache()
|
11 |
|
|
|
12 |
class MultiTaskRoberta(nn.Module):
|
13 |
def __init__(self, base_model):
|
14 |
super().__init__()
|
|
|
23 |
regs = self.regressor(pooled)
|
24 |
return {"logits": logits, "regression_outputs": regs}
|
25 |
|
|
|
26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
27 |
+
print(f"Device: {device}")
|
28 |
|
29 |
+
# Load tokenizer
|
30 |
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
|
31 |
|
32 |
+
# Load base model
|
33 |
base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
|
34 |
model = MultiTaskRoberta(base_model)
|
35 |
|
36 |
+
# Load the model weights from the safetensors file into the multi-task model
|
37 |
+
model_path = "model1.safetensors"
|
38 |
state_dict = load_file(model_path, device="cpu")
|
39 |
model.load_state_dict(state_dict)
|
40 |
model.to(device)
|
41 |
model.eval()
|
42 |
|
43 |
+
# Optional: use half precision on CUDA to reduce memory usage (currently disabled)
|
44 |
# if device.type == 'cuda':
|
45 |
# model.half()
|
|
|
46 |
|
|
|
47 |
def predict(text: str):
|
48 |
try:
|
49 |
inputs = tokenizer(
|
|
|
54 |
max_length=128
|
55 |
)
|
56 |
|
|
|
57 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
58 |
|
59 |
with torch.no_grad():
|
|
|
66 |
pred_class = torch.argmax(out["logits"], dim=-1).item()
|
67 |
sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
|
68 |
|
|
|
69 |
reg_results = out["regression_outputs"][0].cpu().numpy()
|
70 |
rating, delight, anger, sorrow, happiness = reg_results
|
71 |
|
|
|
80 |
except Exception as e:
|
81 |
return {"错误": f"处理失败: {str(e)}"}
|
82 |
|
83 |
+
# Create Gradio interface
|
84 |
iface = gr.Interface(
|
85 |
fn=predict,
|
86 |
inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
|
|
|
102 |
]
|
103 |
)
|
104 |
|
|
|
105 |
if __name__ == "__main__":
|
106 |
iface.launch(share=True, show_error=True)
|