LutherYTT committed on
Commit
65eb2bc
·
1 Parent(s): b5ca05a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -5,11 +5,10 @@ from safetensors.torch import load_file
5
  from transformers import AutoTokenizer, AutoModel
6
  import gc
7
 
8
- # 清理内存
9
  gc.collect()
10
  torch.cuda.empty_cache()
11
 
12
- # 1. 定义MultiTaskRoberta模型架构
13
  class MultiTaskRoberta(nn.Module):
14
  def __init__(self, base_model):
15
  super().__init__()
@@ -24,30 +23,27 @@ class MultiTaskRoberta(nn.Module):
24
  regs = self.regressor(pooled)
25
  return {"logits": logits, "regression_outputs": regs}
26
 
27
- # 2. 准备模型和tokenizer
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- print(f"使用设备: {device}")
30
 
31
- # 加载tokenizer
32
  tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
33
 
34
- # 加载模型
35
  base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
36
  model = MultiTaskRoberta(base_model)
37
 
38
- # 加载权重
39
- model_path = "model.safetensors"
40
  state_dict = load_file(model_path, device="cpu")
41
  model.load_state_dict(state_dict)
42
  model.to(device)
43
  model.eval()
44
 
45
- # 使用半精度减少内存占用
46
  # if device.type == 'cuda':
47
  # model.half()
48
- # print("使用半精度模型")
49
 
50
- # 3. 优化后的推理函数
51
  def predict(text: str):
52
  try:
53
  inputs = tokenizer(
@@ -58,7 +54,6 @@ def predict(text: str):
58
  max_length=128
59
  )
60
 
61
- # 将输入移到设备
62
  inputs = {k: v.to(device) for k, v in inputs.items()}
63
 
64
  with torch.no_grad():
@@ -71,7 +66,6 @@ def predict(text: str):
71
  pred_class = torch.argmax(out["logits"], dim=-1).item()
72
  sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
73
 
74
- # 将结果移回CPU处理
75
  reg_results = out["regression_outputs"][0].cpu().numpy()
76
  rating, delight, anger, sorrow, happiness = reg_results
77
 
@@ -86,7 +80,7 @@ def predict(text: str):
86
  except Exception as e:
87
  return {"错误": f"处理失败: {str(e)}"}
88
 
89
- # 4. 创建Gradio界面
90
  iface = gr.Interface(
91
  fn=predict,
92
  inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
@@ -108,6 +102,5 @@ iface = gr.Interface(
108
  ]
109
  )
110
 
111
- # 5. 启动应用 - 使用兼容的启动方式
112
  if __name__ == "__main__":
113
  iface.launch(share=True, show_error=True)
 
5
  from transformers import AutoTokenizer, AutoModel
6
  import gc
7
 
8
+ # Release memory
9
  gc.collect()
10
  torch.cuda.empty_cache()
11
 
 
12
  class MultiTaskRoberta(nn.Module):
13
  def __init__(self, base_model):
14
  super().__init__()
 
23
  regs = self.regressor(pooled)
24
  return {"logits": logits, "regression_outputs": regs}
25
 
 
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ print(f"Device: {device}")
28
 
29
+ # Load tokenizer
30
  tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
31
 
32
+ # Load base model
33
  base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
34
  model = MultiTaskRoberta(base_model)
35
 
36
+ # Load safetensors
37
+ model_path = "model1.safetensors"
38
  state_dict = load_file(model_path, device="cpu")
39
  model.load_state_dict(state_dict)
40
  model.to(device)
41
  model.eval()
42
 
43
+ # Use half precision to reduce memory usage
44
  # if device.type == 'cuda':
45
  # model.half()
 
46
 
 
47
  def predict(text: str):
48
  try:
49
  inputs = tokenizer(
 
54
  max_length=128
55
  )
56
 
 
57
  inputs = {k: v.to(device) for k, v in inputs.items()}
58
 
59
  with torch.no_grad():
 
66
  pred_class = torch.argmax(out["logits"], dim=-1).item()
67
  sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
68
 
 
69
  reg_results = out["regression_outputs"][0].cpu().numpy()
70
  rating, delight, anger, sorrow, happiness = reg_results
71
 
 
80
  except Exception as e:
81
  return {"错误": f"处理失败: {str(e)}"}
82
 
83
+ # Create Gradio interface
84
  iface = gr.Interface(
85
  fn=predict,
86
  inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
 
102
  ]
103
  )
104
 
 
105
  if __name__ == "__main__":
106
  iface.launch(share=True, show_error=True)