JuliaWolken commited on
Commit
3e36787
·
verified ·
1 Parent(s): 850c5cc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Deploy to Gradio.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/11013nKY0n_XCe50LERXzoqrHsKFwuRBl
8
+ """
9
+
10
+ chunks_data_path = 'WB_KB/chunks.csv')
11
+ model = SentenceTransformer('WB_KB/fine_tuned_model_with_triplets')
12
+ '
13
+ chunks_df = pd.read_csv(chunks_data_path)
14
+ original_chunks = chunks_df['Chunk'].tolist()
15
+
16
+
17
+ chunk_embeddings = model.encode(original_chunks, convert_to_tensor=True)
18
+
19
+
20
+ tokenizer = AutoTokenizer.from_pretrained('DiTy/cross-encoder-russian-msmarco')
21
+ cross_encoder_model = AutoModelForSequenceClassification.from_pretrained('DiTy/cross-encoder-russian-msmarco')
22
+
23
+
24
+ def embed_texts(texts):
25
+ return model.encode(texts, convert_to_tensor=True)
26
+
27
+ def find_relevant_chunks(question_embedding, top_k=5):
28
+ cosine_similarities = cosine_similarity(question_embedding.cpu().numpy(), chunk_embeddings.cpu().numpy()).flatten()
29
+ num_candidates = top_k * 10 # Adjust to get more candidates for re-ranking
30
+ top_indices = cosine_similarities.argsort()[-num_candidates:][::-1]
31
+ return [original_chunks[i] for i in top_indices]
32
+
33
+
34
+ def re_rank(question, candidate_chunks):
35
+ inputs = tokenizer([question] * len(candidate_chunks), candidate_chunks, return_tensors='pt', padding=True, truncation=True, max_length=512)
36
+ with torch.no_grad():
37
+ scores = cross_encoder_model(**inputs).logits.squeeze()
38
+ ranked_indices = scores.argsort(descending=True)
39
+ return [candidate_chunks[i] for i in ranked_indices]
40
+
41
+
42
+ def find_relevant_chunks_with_reranking(question, top_k=5):
43
+ question_embedding = embed_texts([question])
44
+ candidate_chunks = find_relevant_chunks(question_embedding, top_k=top_k)
45
+ ranked_chunks = re_rank(question, candidate_chunks) if len(candidate_chunks) > 1 else candidate_chunks
46
+ return ranked_chunks[:top_k]
47
+
48
+
49
+ def answer_question(question):
50
+
51
+ if not question or len(question) < 10:
52
+ return "Пожалуйста, задайте вопрос. Количество символов должно превышать 10."
53
+
54
+
55
+ if not re.search(r'[а-яА-Я]', question):
56
+ return "Простите, на этом языке я пока не говорю. Попробуем еще раз?"
57
+
58
+
59
+ top_chunks = find_relevant_chunks_with_reranking(question, top_k=5)
60
+
61
+
62
+ if not top_chunks:
63
+ return "Ничего не нашлось. Я только учусь, сформулируйте вопрос иначе, пожалуйста"
64
+
65
+
66
+ return "\n\n".join([f"Answer {i+1}: {chunk}" for i, chunk in enumerate(top_chunks)])
67
+
68
+ # Set up Gradio interface
69
+ iface = gr.Interface(
70
+ fn=answer_question,
71
+ inputs="text",
72
+ outputs="text",
73
+ title="Question Answering Model",
74
+ description="Здравствуйте! Задайте мне вопрос на русском о работе пунктов выдачи WB, и я постараюсь найти самые лучшие ответы."
75
+ )
76
+
77
+ # Launch the Gradio interface with shareable link
78
+ iface.launch(share=True)
79
+
80
+ """URL: https://54bac118d6e06c6b04.gradio.live
81
+
82
+ Здравствуйте! Эта ссылка будет доступна в течение всего трех дней. Если нам потребуется больше времени, пожалуйста, сообщите! Я запущу приложение заново. Мой телеграм для связи @juliawolkenstein
83
+
84
+ """