liuwei commited on
Commit
e43fec0
·
1 Parent(s): d821fec
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -36,19 +36,19 @@ def app():
36
  button_clear = st.button("清空")
37
 
38
  def transformer(source_text, sentences):
 
39
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")
40
  source_emb = model.encode(source_text, convert_to_tensor=True)
41
  sent_embs = model.encode(sentences, convert_to_tensor=True)
 
42
  cos_sim = util.cos_sim(source_emb, sent_embs)
43
 
44
- st.write(source_emb)
45
- st.write(sent_embs)
46
- st.write(cos_sim) #output tensor([[1.0000, 0.3624]])
47
  cosin_dict = {}
48
 
49
  for i, cos in enumerate(torch.flatten(cos_sim)):
50
  cosin_dict[sentences[i]] = cos
51
 
 
52
  sorted_dict = dict(sorted(cosin_dict.items(), key=lambda item: item[1],reverse = True))
53
  return sorted_dict
54
 
@@ -57,16 +57,11 @@ def app():
57
  # embeddings
58
  embeddings = transformer([source_text], st.session_state.inputs)
59
 
60
- # 显示生成的文本
61
- st.write(embeddings)
62
- #output_text.success(generated_text)
63
-
64
  with st.container():
65
  for sent, cos in embeddings.items():
66
  col1, col2, col3 = st.columns(3)
67
-
68
  cos_value = round(float(cos.item()),4)
69
- st.write(cos_value)
70
  with col1:
71
  st.text(sent)
72
  with col2:
@@ -75,6 +70,7 @@ def app():
75
  st.text(cos_value)
76
 
77
  if button_clear:
 
78
  st.session_state.inputs.clear()
79
  del st.session_state["inputs"]
80
  st.session_state.inputs_index = 0
 
36
  button_clear = st.button("清空")
37
 
38
  def transformer(source_text, sentences):
39
+ #使用模型分别计算源字符串和对比字符串的embedding
40
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")
41
  source_emb = model.encode(source_text, convert_to_tensor=True)
42
  sent_embs = model.encode(sentences, convert_to_tensor=True)
43
+ #计算源字符串和对比字符串embedding的cos值
44
  cos_sim = util.cos_sim(source_emb, sent_embs)
45
 
 
 
 
46
  cosin_dict = {}
47
 
48
  for i, cos in enumerate(torch.flatten(cos_sim)):
49
  cosin_dict[sentences[i]] = cos
50
 
51
+ #根据cos值降序排列
52
  sorted_dict = dict(sorted(cosin_dict.items(), key=lambda item: item[1],reverse = True))
53
  return sorted_dict
54
 
 
57
  # embeddings
58
  embeddings = transformer([source_text], st.session_state.inputs)
59
 
60
+ # 显示对比的字符串、进度条、cos值
 
 
 
61
  with st.container():
62
  for sent, cos in embeddings.items():
63
  col1, col2, col3 = st.columns(3)
 
64
  cos_value = round(float(cos.item()),4)
 
65
  with col1:
66
  st.text(sent)
67
  with col2:
 
70
  st.text(cos_value)
71
 
72
  if button_clear:
73
+ #清空对比字符串
74
  st.session_state.inputs.clear()
75
  del st.session_state["inputs"]
76
  st.session_state.inputs_index = 0