WXM2000 committed on
Commit 4ef25f5 · 1 Parent(s): 7addb3f

Create app.py

Files changed (1)
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
+ import re
+ import warnings
+
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+ import gradio as gr
+ import torch
+
+ import numpy as np
+ import torch.nn as nn
+ from PIL import Image
+ from torch import LongTensor, FloatTensor
+ from torch.autograd import Function
+ from torchvision.transforms.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
+ from transformers import BertModel, BertConfig, BertTokenizer
+ from torch.utils.data import Dataset, DataLoader
+
+ warnings.filterwarnings('ignore')
+
+
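+ # Custom autograd functions for the domain-adversarial branch: Exp computes
+ # exp(x) with its exact gradient, ReverseLayerF is a gradient-reversal layer
+ # (currently unused), and grad_reverse() applies Exp to the fused features.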
+ class Exp(Function):
+     @staticmethod
+     def forward(ctx, i):
+         result = i.exp()
+         ctx.save_for_backward(result)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         result, = ctx.saved_tensors
+         return grad_output * result
+
+
+ class ReverseLayerF(Function):
+
+     # @staticmethod
+     def forward(self, x, args):
+         self.lambd = args.lambd
+         return x.view_as(x)
+
+     # @staticmethod
+     def backward(self, grad_output):
+         return (grad_output * -self.lambd)
+
+
+ def grad_reverse(x):
+     return Exp.apply(x)
+
+
+ class Config():
+     def __init__(self):
+         self.batch_size = 16
+         self.epochs = 200
+         self.bert_path = "./fake-news-bert/"
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.event_num = 30
+
+
+ class FakeNewsDataset(Dataset):
+     def __init__(self, input_three, event, image, label):
+         self.event = LongTensor(list(event))
+         self.image = LongTensor([np.array(i) for i in image])
+         self.label = LongTensor(list(label))
+         self.input_three = input_three
+         self.input_three[0] = LongTensor(self.input_three[0])
+         self.input_three[1] = LongTensor(self.input_three[1])
+         self.input_three[2] = LongTensor(self.input_three[2])
+
+     def __len__(self):
+         return len(self.label)
+
+     def __getitem__(self, idx):
+         return self.input_three[0][idx], self.input_three[1][idx], self.input_three[2][idx], self.image[idx], \
+             self.event[idx], self.label[idx]
+
+
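+ # Multi-modal classifier: BERT pooled text features and a small CNN over the
+ # 3x224x224 image are each projected to p dimensions, concatenated, and fed to
+ # a real/fake class classifier and an event (domain) classifier.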
+ class Multi_Model(nn.Module):
+     def __init__(self, bert_path, event_num, classes=2, p=10):
+         super(Multi_Model, self).__init__()
+         self.config = BertConfig.from_pretrained("./fake-news-bert/config.json")  # load the model hyperparameters
+         self.bert = BertModel.from_pretrained(bert_path, config=self.config)  # load the pretrained model weights
+         self.fc = nn.Linear(self.config.hidden_size, p)  # direct classification head
+         self.event_num = event_num
+         '''
+         vgg_19 = torchvision.models.vgg19(pretrained=True)
+         for param in vgg_19.parameters():
+             param.requires_grad = False
+         num_ftrs = vgg_19.classifier._modules['6'].out_features
+         self.vgg = vgg_19
+         '''
+         # self.image_fc1 = nn.Linear(num_ftrs, p)
+         # input 3*224*224
+
+         self.cnn = nn.Sequential(
+             nn.Conv2d(3, 1, kernel_size=5, stride=2, padding=2),  # 1*112*112
+             nn.ReLU(),
+             nn.MaxPool2d(2),  # 1*56*56
+             nn.Conv2d(1, 1, kernel_size=5, stride=2, padding=0),  # 1*26*26
+             nn.ReLU(),
+         )
+         self.image_fc = nn.Sequential(
+             nn.Linear(1 * 26 * 26, 26),
+             nn.Linear(26, p),
+         )
+
+         # self.image_classifier = nn.Sequential(
+         #     K.VisionTransformer(image_size=224, patch_size=16),
+         #     K.ClassificationHead(num_classes=10)  # adjustment needed when p changes
+         # )
+
+         self.softmax = nn.Softmax(dim=1)
+         self.class_classifier = nn.Sequential()
+         self.class_classifier.add_module('c_fc1', nn.Linear(2 * p, p))
+         self.class_classifier.add_module('c_fc2', nn.Linear(p, 2))
+         self.class_classifier.add_module('c_softmax', nn.Softmax(dim=1))
+
+         self.domain_classifier = nn.Sequential()
+         self.domain_classifier.add_module('d_fc1', nn.Linear(2 * p, p))
+         # self.domain_classifier.add_module('d_bn1', nn.BatchNorm2d(self.hidden_size))
+         self.domain_classifier.add_module('d_relu1', nn.LeakyReLU(True))
+         self.domain_classifier.add_module('d_fc2', nn.Linear(p, self.event_num))
+         self.domain_classifier.add_module('d_softmax', nn.Softmax(dim=1))
+
+     def forward(self, input_ids, attention_mask=None, token_type_ids=None, image=None):
+         outputs = self.bert(input_ids, attention_mask, token_type_ids)
+         out_pool = outputs[1]  # pooled output, [bs, config.hidden_size]
+         text = self.fc(out_pool)  # [bs, p]
+         # image = self.vgg(image)  # [N, 512]
+         # image = F.leaky_relu(self.image_fc1(image))
+         image = self.cnn(image)
+         # image = self.image_classifier(image)
+         image = self.image_fc(image.view(image.size(0), -1))
+         text_image = torch.cat((text, image), 1)
+
+         class_output = self.class_classifier(text_image)
+         reverse_feature = grad_reverse(text_image)
+         domain_output = self.domain_classifier(reverse_feature)
+         return class_output, domain_output
+
+
+ def cleanSST(string):
+     string = re.sub(u"[,。 :,.;|-“”——_/nbsp+&;@、《》~()())#O!:【】]", "", string)
+     return string.strip().lower()
+
+ image_path = './1.jpg'
+ example_image = Image.open(image_path)
+ example_text = '2024年是世界末日,我们完蛋了,世界要毁灭了'  # example claim: "2024 is doomsday, we are finished, the world will be destroyed"
+ def predict(input_text, input_image):
+     data_transforms = Compose(transforms=[
+         Resize(256),
+         CenterCrop(224),
+         ToTensor(),
+         Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     text = input_text
+     # text = '2024年是世界末日,我们完蛋了,世界要毁灭了'
+     # image_path = '1.jpg'
+     multi_model = Multi_Model("./fake-news-bert/", 30)  # the 30 (event_num) does not matter at inference time
+     multi_model.load_state_dict(torch.load('./fake-news-bert/best_multi_bert_model.pth', map_location=device))
+     multi_model.eval()
+     # im = Image.open(image_path).convert('RGB')
+     im = input_image.convert('RGB')
+     # im = Image.fromarray(input_image).convert('RGB')
+     im = data_transforms(im)
+     # this directory holds three files ('vocab.txt', 'pytorch_model.bin', 'config.json')
+     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
+     input_ids, input_masks, input_types = [], [], []
+     encode_dict = tokenizer.encode_plus(text=cleanSST(text), max_length=50,
+                                         padding='max_length', truncation=True)
+     multi_model.to(device)
+
+     input_ids.append(encode_dict['input_ids'])
+     input_types.append(encode_dict['token_type_ids'])
+     input_masks.append(encode_dict['attention_mask'])
+     label_pred, yyy = multi_model(LongTensor(input_ids).to(device),
+                                   attention_mask=LongTensor(input_masks).to(device),
+                                   token_type_ids=LongTensor(input_types).to(device),
+                                   image=FloatTensor([np.array(im)]).to(device))
+
+     print(label_pred.shape)
+     print(label_pred)
+     y_pred = torch.argmax(label_pred, dim=1).detach().cpu().numpy().tolist()
+
+     print(y_pred)
+     print("fake news :", text)
+     # print("image path:", image_path)
+     if y_pred[0] == 0:
+         # print('Real News')
+         output_text = '真实新闻'  # "Real news"
+     else:
+         # print('Fake News')
+         output_text = '虚假新闻'  # "Fake news"
+     return output_text
+
+ # examples=['2024年是世界末日,我们完蛋了,世界要毁灭了']
+
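+ # Gradio Blocks UI: text and image inputs on the left, the prediction textbox
+ # and the "检测" (Detect) button on the right, with one pre-filled example.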
+ # demo = gr.Interface(predict,
+ #                     inputs=[gr.Textbox(lines=2, placeholder="在这里输入需要检测新闻的文本内容"), "image"],
+ #                     outputs="text")  # ,
+ # # examples=examples)
+ css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("<h1><center>虚假新闻检测</center></h1>")
+     with gr.Row():
+         with gr.Column():
+             inp_txt = gr.Textbox(lines=2, placeholder="在这里输入需要检测新闻的文本内容")
+             inp_img = gr.Image(type='pil')
+             inp = [inp_txt, inp_img]
+         with gr.Column():
+             out = gr.Textbox(lines=2)
+             btn = gr.Button("检测")
+             btn.click(fn=predict, inputs=inp, outputs=out)
+
+     examples = [[example_text, image_path]]
+     gr.Examples(
+         examples=examples,
+         inputs=inp,
+     )
+
+
+ demo.launch()