Spaces:
Runtime error
Runtime error
File size: 2,893 Bytes
8428b8e 7fcfcdc 30a40c5 8428b8e 697c4fb 8428b8e 30a40c5 8428b8e ea6e2e4 8428b8e 285e2b7 e43fec0 7b45115 8428b8e 30a40c5 e43fec0 30a40c5 8428b8e 15936c4 8428b8e 7fcfcdc 30a40c5 8428b8e e43fec0 8428b8e 285e2b7 8428b8e e43fec0 285e2b7 3aa1ebd 51f25dc d821fec 22c71f8 51f25dc 1878bd2 51f25dc 1878bd2 285e2b7 ea6e2e4 e43fec0 ea6e2e4 9db1c97 ea6e2e4 30a40c5 ea6e2e4 285e2b7 8428b8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
# Switch the page to Streamlit's wide layout (the default is "centered") so
# the result rows have more horizontal room.
_PAGE_LAYOUT = "wide"
st.set_page_config(layout=_PAGE_LAYOUT)
def app():
# 创建Streamlit应用程序
st.title("对比句子的相似度")
source_text = st.text_input("源句子", value="")
st.write("待比较的句子:")
if "inputs" not in st.session_state:
# 创建一个空列表来存储输入框列表
st.session_state.inputs = []
st.session_state.inputs_index = 0
with st.container():
# 在容器中渲染已经存在的输入框列表
for i in range(0, st.session_state.inputs_index):
st.session_state.inputs[i]= st.text_input(f"请输入第 {i+1} 个句子", "", key=i)
# 创建一个添加输入框的按钮
add_input_button = st.button("添加一个待比较句子")
# 当用户点击按钮时往容器中添加新的输入框
if add_input_button:
i = st.session_state.inputs_index
st.session_state.inputs.append(st.text_input(f"请输入第 {i+1} 个句子", "", key=i))
# 自增输入框的key
st.session_state.inputs_index += 1
button_generate = st.button("计算")
button_clear = st.button("清空")
def transformer(source_text, sentences):
#使用模型分别计算源字符串和对比字符串的embedding
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device="cpu")
source_emb = model.encode(source_text, convert_to_tensor=True)
sent_embs = model.encode(sentences, convert_to_tensor=True)
#计算源字符串和对比字符串embedding的cos值
cos_sim = util.cos_sim(source_emb, sent_embs)
cosin_dict = {}
for i, cos in enumerate(torch.flatten(cos_sim)):
cosin_dict[sentences[i]] = cos
#根据cos值降序排列
sorted_dict = dict(sorted(cosin_dict.items(), key=lambda item: item[1],reverse = True))
return sorted_dict
if button_generate:
# embeddings
embeddings = transformer([source_text], st.session_state.inputs)
# 显示对比的字符串、进度条、cos值
with st.container():
for sent, cos in embeddings.items():
col1, col2, col3 = st.columns(3)
cos_value = round(float(cos.item()),4)
with col1:
st.text(sent)
with col2:
bar = st.progress(cos_value)
with col3:
st.text(cos_value)
if button_clear:
#清空对比字符串
st.session_state.inputs.clear()
del st.session_state["inputs"]
st.session_state.inputs_index = 0
source_text = ''
st.experimental_rerun()
if __name__ == "__main__":
    # Run the application.
    app()