strongeryongchao commited on
Commit
68a52f8
·
verified ·
1 Parent(s): ad6328e

create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from flask import Flask, request, jsonify
3
+ import gradio as gr
4
+ from transformers import pipeline
5
+ import pandas as pd
6
+ import uuid
7
+ import os
8
+
9
+ # 初始化 Flask(仅用于API)
10
+ flask_app = Flask(__name__)
11
+ os.makedirs("static", exist_ok=True)
12
+
13
+ # 直接使用 Hugging Face 模型(无需本地存储)
14
+ model_name = "uer/roberta-base-finetuned-jd-binary-chinese" # 使用Hub上的模型
15
+ classifier = pipeline("text-classification", model=model_name)
16
+
17
+ @flask_app.route("/api/predict", methods=["POST"])
18
+ def api_predict():
19
+ text = request.json.get("text", "")
20
+ result = classifier(text)
21
+ return jsonify(result)
22
+
23
+ # Gradio 界面
24
+ with gr.Blocks(title="情感分析系统") as demo:
25
+ with gr.Tab("单句分析"):
26
+ input_text = gr.Textbox(label="输入文本")
27
+ analyze_btn = gr.Button("分析")
28
+ output_label = gr.Label(label="结果")
29
+
30
+ def predict(text):
31
+ result = classifier(text)[0]
32
+ return {result["label"]: result["score"]}
33
+
34
+ analyze_btn.click(
35
+ fn=predict,
36
+ inputs=input_text,
37
+ outputs=output_label
38
+ )
39
+
40
+ with gr.Tab("批量分析"):
41
+ file_input = gr.File(label="上传TXT文件")
42
+ batch_output = gr.Dataframe(headers=["文本", "标签", "置信度"])
43
+ process_btn = gr.Button("处理文件")
44
+
45
+ def process_file(file):
46
+ with open(file.name, "r", encoding="utf-8") as f:
47
+ texts = [line.strip() for line in f if line.strip()]
48
+ results = []
49
+ for text in texts:
50
+ pred = classifier(text)[0]
51
+ results.append([text, pred["label"], f"{pred['score']:.4f}"])
52
+ return pd.DataFrame(results, columns=["文本", "标签", "置信度"])
53
+
54
+ process_btn.click(
55
+ fn=process_file,
56
+ inputs=file_input,
57
+ outputs=batch_output
58
+ )
59
+
60
+ # 启动 Flask 线程(仅用于API)
61
+ import threading
62
+ threading.Thread(
63
+ target=flask_app.run,
64
+ kwargs={"port": 8000, "host": "0.0.0.0"}
65
+ ).start()
66
+
67
+ # 启动 Gradio
68
+ demo.launch(server_port=7860)