StKirill commited on
Commit
763c5e6
·
verified ·
1 Parent(s): 809df88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -311,6 +311,7 @@ def chat_bert_context(question, history):
311
 
312
  #-------------------------------------Bi-BERT-Encoder------------------------------------------#
313
  MAX_LENGTH = 128
 
314
  # Define function for mean-pooling
315
  def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
316
  in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
@@ -372,12 +373,13 @@ def chat_bi_bert(question):
372
  cosine_similarities = cosine_similarity([question], question_embeds).flatten()
373
  top_indice = np.argmax(cosine_similarities, axis=0)
374
  answer = df['answer'].iloc[top_indice]
 
375
  return answer
376
 
377
 
378
 
379
  #-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
380
- inverted_answer = dict(enumerate(df.answer.tolist()))
381
 
382
  #Define class for CrossEncoderBert
383
  class CrossEncoderBert(torch.nn.Module):
@@ -418,9 +420,9 @@ def chat_cross_bert(question):
418
 
419
  # Process scores for finetuned model
420
  scores = ce_scores.cpu().numpy()
421
- scores_ix = np.argmax(scores)
422
  # print(f"{corpus[scores_ix]}")
423
- return corpus[scores_ix]
424
 
425
  # gradio part
426
  def echo(message, history, model):
 
311
 
312
  #-------------------------------------Bi-BERT-Encoder------------------------------------------#
313
  MAX_LENGTH = 128
314
+ inverted_answer = dict(enumerate(df.answer.tolist()))
315
  # Define function for mean-pooling
316
  def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
317
  in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
 
373
  cosine_similarities = cosine_similarity([question], question_embeds).flatten()
374
  top_indice = np.argmax(cosine_similarities, axis=0)
375
  answer = df['answer'].iloc[top_indice]
376
+ answer = inverted_answer[top_indice]
377
  return answer
378
 
379
 
380
 
381
  #-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
382
+
383
 
384
  #Define class for CrossEncoderBert
385
  class CrossEncoderBert(torch.nn.Module):
 
420
 
421
  # Process scores for finetuned model
422
  scores = ce_scores.cpu().numpy()
423
+ ix = np.argmax(scores)
424
  # print(f"{corpus[scores_ix]}")
425
+ return corpus[ix]
426
 
427
  # gradio part
428
  def echo(message, history, model):