ssocean commited on
Commit
1fd1dd9
·
verified ·
1 Parent(s): 83d6a82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -39
app.py CHANGED
@@ -4,8 +4,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch.nn.functional as F
5
  import torch.nn as nn
6
  import re
7
-
8
- model_path = "ssocean/NAIP" # 更换为你的模型路径
9
  model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=1, load_in_8bit=True)
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -14,45 +13,17 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  model.eval()
16
 
17
- def validate_input(title, abstract):
18
- """验证输入是否符合要求"""
19
-
20
- # 黑名单:屏蔽非拉丁字符
21
- non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
22
- if len(title.split(' '))<4:
23
- return False, "The title must be at least 3 words long."
24
- if len(abstract.split(' ')) < 50:
25
- return False, "The abstract must be at least 50 words long."
26
- if len((title + abstract).split(' '))>1024:
27
- return True, "Warning, The input length is approaching tokenization limits (1024) and may be truncated without further warning!"
28
- if non_latin_pattern.search(title):
29
- return False, "The title contains invalid characters. Only English letters and special symbols are allowed."
30
- if non_latin_pattern.search(abstract):
31
- return False, "The abstract contains invalid characters. Only English letters and special symbols are allowed."
32
-
33
- return True, "Inputs are valid! Good to go!"
34
-
35
- def update_button_status(title, abstract):
36
- """根据输入内容动态更新按钮状态"""
37
- valid, message = validate_input(title, abstract)
38
- if not valid:
39
- return gr.update(value="Error: " + message), gr.update(interactive=False)
40
- return gr.update(value=message), gr.update(interactive=True)
41
-
42
 
43
  def predict(title, abstract):
44
  text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
45
  inputs = tokenizer(text, return_tensors="pt")
46
  with torch.no_grad():
47
  outputs = model(**inputs.to(device))
48
- probability = torch.sigmoid(outputs.logits).item() + 0.05
49
  # reason for +0.05: We observed that the predicted values in the web demo are generally around 0.05 lower than those in the local deployment (due to differences in software/hardware environments). Therefore, we applied the following compensation in the web demo. Please do not use this in the local deployment.
50
-
51
- # Clamp the value to ensure it is between 0 and 1 (for probabilities)
52
- clamped_probability = torch.clamp(probability, min=0.0, max=1.0)
53
-
54
- # Return the clamped probability, rounded to 4 decimal places
55
- return round(clamped_probability, 4)
56
 
57
 
58
  # 示例数据
@@ -106,6 +77,31 @@ examples = [
106
  ]
107
  ]
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # 创建 Gradio 界面
110
  with gr.Blocks() as iface:
111
  gr.Markdown("""
@@ -160,10 +156,6 @@ with gr.Blocks() as iface:
160
  - Predicted impact is a probabilistic value generated by the model and does not reflect paper quality or novelty.
161
  - The author takes no responsibility for the prediction results.
162
  - To identify potentially impactful papers, this study uses the sigmoid+MSE approach to optimize NDCG values (over sigmoid+BCE), resulting in predicted values concentrated between 0.1 and 0.9 due to the sigmoid gradient effect.
163
- - Generally, it is considered a predicted influence score greater than 0.65 to indicate an exceptionally impactful paper.
164
  """)
165
  iface.launch()
166
-
167
-
168
-
169
-
 
4
  import torch.nn.functional as F
5
  import torch.nn as nn
6
  import re
7
+ model_path = r'ssocean/NAIP'
 
8
  model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=1, load_in_8bit=True)
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
13
 
14
  model.eval()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def predict(title, abstract):
18
  text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
19
  inputs = tokenizer(text, return_tensors="pt")
20
  with torch.no_grad():
21
  outputs = model(**inputs.to(device))
22
+ probability = torch.sigmoid(outputs.logits).item()
23
  # reason for +0.05: We observed that the predicted values in the web demo are generally around 0.05 lower than those in the local deployment (due to differences in software/hardware environments). Therefore, we applied the following compensation in the web demo. Please do not use this in the local deployment.
24
+ if probability + 0.05 >=1.0:
25
+ return round(1, 4)
26
+ return round(probability + 0.05, 4)
 
 
 
27
 
28
 
29
  # 示例数据
 
77
  ]
78
  ]
79
 
80
+ def validate_input(title, abstract):
81
+ """验证输入是否符合要求"""
82
+
83
+ # 黑名单:屏蔽非拉丁字符
84
+ non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
85
+ if len(title.split(' '))<4:
86
+ return False, "The title must be at least 3 words long."
87
+ if len(abstract.split(' ')) < 50:
88
+ return False, "The abstract must be at least 50 words long."
89
+ if len((title + abstract).split(' '))>1024:
90
+ return True, "Warning, The input length is approaching tokenization limits (1024) and may be truncated without further warning!"
91
+ if non_latin_pattern.search(title):
92
+ return False, "The title contains invalid characters. Only English letters and special symbols are allowed."
93
+ if non_latin_pattern.search(abstract):
94
+ return False, "The abstract contains invalid characters. Only English letters and special symbols are allowed."
95
+
96
+ return True, "Inputs are valid! Good to go!"
97
+
98
+ def update_button_status(title, abstract):
99
+ """根据输入内容动态更新按钮状态"""
100
+ valid, message = validate_input(title, abstract)
101
+ if not valid:
102
+ return gr.update(value="Error: " + message), gr.update(interactive=False)
103
+ return gr.update(value=message), gr.update(interactive=True)
104
+
105
  # 创建 Gradio 界面
106
  with gr.Blocks() as iface:
107
  gr.Markdown("""
 
156
  - Predicted impact is a probabilistic value generated by the model and does not reflect paper quality or novelty.
157
  - The author takes no responsibility for the prediction results.
158
  - To identify potentially impactful papers, this study uses the sigmoid+MSE approach to optimize NDCG values (over sigmoid+BCE), resulting in predicted values concentrated between 0.1 and 0.9 due to the sigmoid gradient effect.
159
+ - Generally, it is considered a predicted influence score greater than 0.65 to indicate an impactful paper.
160
  """)
161
  iface.launch()