File size: 2,893 Bytes
8428b8e
7fcfcdc
30a40c5
8428b8e
697c4fb
 
 
8428b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a40c5
 
 
 
 
 
 
 
 
8428b8e
 
ea6e2e4
8428b8e
285e2b7
e43fec0
7b45115
8428b8e
30a40c5
e43fec0
30a40c5
8428b8e
15936c4
8428b8e
7fcfcdc
30a40c5
8428b8e
e43fec0
8428b8e
 
 
 
 
 
285e2b7
8428b8e
e43fec0
285e2b7
3aa1ebd
51f25dc
d821fec
22c71f8
 
51f25dc
1878bd2
51f25dc
1878bd2
285e2b7
ea6e2e4
e43fec0
ea6e2e4
9db1c97
ea6e2e4
30a40c5
 
ea6e2e4
285e2b7
 
8428b8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util

# Use the wide page layout.
st.set_page_config(layout="wide")

def app():
    """Render the sentence-similarity page and handle its interactions.

    The page shows one source-sentence box, a user-growable list of
    candidate-sentence boxes, and three buttons (add / compute / clear).
    Candidate values are kept in ``st.session_state.inputs`` and the number
    of rendered candidate boxes in ``st.session_state.inputs_index``.
    Pressing "compute" embeds the source and the non-blank candidates and
    lists the candidates from most to least similar.
    """
    source_key = "source_text"

    st.title("对比句子的相似度")

    # Keyed so the clear callback can reset the box through session state;
    # rebinding a local variable would not change what the widget shows.
    source_text = st.text_input("源句子", key=source_key)

    st.write("待比较的句子:")

    if "inputs" not in st.session_state:
        # One slot per candidate box, plus the count of boxes to render.
        st.session_state.inputs = []
        st.session_state.inputs_index = 0

    def _add_sentence_input():
        """Reserve one more candidate slot; the loop below renders its box."""
        st.session_state.inputs.append("")
        st.session_state.inputs_index += 1

    def _clear_all():
        """Drop every candidate box and empty the source box."""
        st.session_state.inputs = []
        st.session_state.inputs_index = 0
        st.session_state[source_key] = ""

    def transformer(source_text, sentences):
        """Rank ``sentences`` by cosine similarity to ``source_text``.

        Args:
            source_text: The source sentence(s) passed to ``model.encode``.
            sentences: Non-empty list of candidate sentences.

        Returns:
            A list of ``(sentence, score)`` pairs sorted by descending score,
            where ``score`` is a plain ``float``. A list is used instead of a
            dict keyed by sentence so that duplicate candidates each keep
            their own row.
        """
        # NOTE(review): the model is re-created on every call; consider
        # caching it if the installed Streamlit version offers a resource cache.
        model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device="cpu")
        source_emb = model.encode(source_text, convert_to_tensor=True)
        sent_embs = model.encode(sentences, convert_to_tensor=True)
        # Flatten the similarity matrix to one score per candidate.
        scores = torch.flatten(util.cos_sim(source_emb, sent_embs)).tolist()
        return sorted(zip(sentences, scores), key=lambda pair: pair[1], reverse=True)

    with st.container():
        # Render the existing candidate boxes and store their current values.
        for i in range(st.session_state.inputs_index):
            st.session_state.inputs[i] = st.text_input(f"请输入第 {i+1} 个句子", "", key=i)

    # Callbacks run before the script re-executes, so the new box shows up
    # inside the container straight away and no explicit rerun is needed.
    st.button("添加一个待比较句子", on_click=_add_sentence_input)

    button_generate = st.button("计算")
    st.button("清空", on_click=_clear_all)

    if button_generate:
        # Blank boxes carry nothing to compare and are skipped.
        candidates = [s for s in st.session_state.inputs if s.strip()]
        if not source_text.strip() or not candidates:
            st.warning("请先输入源句子和至少一个待比较的句子")
        else:
            ranked = transformer([source_text], candidates)

            # One row per candidate: sentence, progress bar, score.
            with st.container():
                for sent, score in ranked:
                    col1, col2, col3 = st.columns(3)
                    cos_value = round(score, 4)
                    with col1:
                        st.text(sent)
                    with col2:
                        # Cosine similarity may be negative, but a float
                        # progress value must stay within [0.0, 1.0].
                        st.progress(min(max(cos_value, 0.0), 1.0))
                    with col3:
                        st.text(cos_value)

   

if __name__ == "__main__":
    # Run the application.
    app()