Update app.py
Browse files
app.py
CHANGED
@@ -311,6 +311,7 @@ def chat_bert_context(question, history):
|
|
311 |
|
312 |
#-------------------------------------Bi-BERT-Encoder------------------------------------------#
|
313 |
MAX_LENGTH = 128
|
|
|
314 |
# Define function for mean-pooling
|
315 |
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
|
316 |
in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
|
@@ -372,12 +373,13 @@ def chat_bi_bert(question):
|
|
372 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
373 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
374 |
answer = df['answer'].iloc[top_indice]
|
|
|
375 |
return answer
|
376 |
|
377 |
|
378 |
|
379 |
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
|
380 |
-
|
381 |
|
382 |
#Define class for CrossEncoderBert
|
383 |
class CrossEncoderBert(torch.nn.Module):
|
@@ -418,9 +420,9 @@ def chat_cross_bert(question):
|
|
418 |
|
419 |
# Process scores for finetuned model
|
420 |
scores = ce_scores.cpu().numpy()
|
421 |
-
|
422 |
# print(f"{corpus[scores_ix]}")
|
423 |
-
return corpus[
|
424 |
|
425 |
# gradio part
|
426 |
def echo(message, history, model):
|
|
|
311 |
|
312 |
#-------------------------------------Bi-BERT-Encoder------------------------------------------#
|
313 |
MAX_LENGTH = 128
|
314 |
+
inverted_answer = dict(enumerate(df.answer.tolist()))
|
315 |
# Define function for mean-pooling
|
316 |
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
|
317 |
in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
|
|
|
373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
375 |
answer = df['answer'].iloc[top_indice]
|
376 |
+
answer = inverted_answer[top_indice]
|
377 |
return answer
|
378 |
|
379 |
|
380 |
|
381 |
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
|
382 |
+
|
383 |
|
384 |
#Define class for CrossEncoderBert
|
385 |
class CrossEncoderBert(torch.nn.Module):
|
|
|
420 |
|
421 |
# Process scores for finetuned model
|
422 |
scores = ce_scores.cpu().numpy()
|
423 |
+
ix = np.argmax(scores)
|
424 |
# print(f"{corpus[scores_ix]}")
|
425 |
+
return corpus[ix]
|
426 |
|
427 |
# gradio part
|
428 |
def echo(message, history, model):
|