Walid Aissa committed
Commit 80e614a
1 Parent(s): 064fc00

better wikipedia search

Files changed (1):
app.py +22 -9
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 import gradio as gr
 import numpy as np
 import wikipediaapi as wk
+import wikipedia
 from transformers import (
     TokenClassificationPipeline,
     AutoModelForTokenClassification,
@@ -11,7 +12,7 @@ from transformers import (
 )
 from transformers.pipelines import AggregationStrategy
 import torch
-
+print("hello")
 # =====[ DEFINE PIPELINE ]===== #
 class KeyphraseExtractionPipeline(TokenClassificationPipeline):
     def __init__(self, model, *args, **kwargs):
@@ -43,26 +44,36 @@ def keyphrases_extraction(text: str) -> str:
 def wikipedia_search(input: str) -> str:
     input = input.replace("\n", " ")
     keyphrases = keyphrases_extraction(input)
+
     wiki = wk.Wikipedia('en')
 
     try :
         #TODO: add better extraction and search
-        keyphrase_index = 0
-        page = wiki.page(keyphrases[keyphrase_index])
+        if len(keyphrases) == 0:
+            return "Can you add more details to your question?"
+
+        query_suggestion = wikipedia.suggest(keyphrases[0])
+        if(query_suggestion != None):
+            results = wikipedia.search(query_suggestion)
+        else:
+            results = wikipedia.search(keyphrases[0])
 
+        index = 0
+        page = wiki.page(results[index])
         while not ('.' in page.summary) or not page.exists():
-            keyphrase_index += 1
-            if keyphrase_index == len(keyphrases):
+            index += 1
+            if index == len(results):
                 raise Exception
-            page = wiki.page(keyphrases[keyphrase_index])
-        return page.summary
+            page = wiki.page(results[index])
+        return page.summary
+
     except:
         return "I cannot answer this question"
 
 def answer_question(question):
 
     context = wikipedia_search(question)
-    if context == "I cannot answer this question":
+    if (context == "I cannot answer this question") or (context == "Can you add more details to your question?"):
         return context
 
     # ======== Tokenize ========
@@ -99,6 +110,8 @@ def answer_question(question):
 
     start_scores = outputs.start_logits
     end_scores = outputs.end_logits
+    print(start_scores)
+    print(end_scores)
 
     # ======== Reconstruct Answer ========
     # Find the tokens with the highest `start` and `end` scores.
@@ -130,7 +143,7 @@ examples = [
     ["Where is the Eiffel Tower?"],
     ["What is the population of France?"]
 ]
-
+print("hello")
 demo = gr.Interface(
     title = title,
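In effect, the new wikipedia_search flow is: take the first extracted keyphrase, ask the wikipedia package for a spelling suggestion, search with whichever query that yields, then walk the search results until one resolves to an existing page whose summary contains at least one full sentence. A minimal standalone sketch of that flow, assuming the wikipedia and wikipedia-api packages are installed; the fetch_summary name is illustrative, and newer wikipedia-api releases may additionally require a user_agent argument to wk.Wikipedia:

import wikipediaapi as wk
import wikipedia

def fetch_summary(keyphrase: str) -> str:
    wiki = wk.Wikipedia('en')

    # Prefer the spelling suggestion when one exists; otherwise
    # search with the raw keyphrase.
    suggestion = wikipedia.suggest(keyphrase)
    results = wikipedia.search(suggestion if suggestion else keyphrase)

    # Walk the candidates until one is a real page with a usable summary.
    for title in results:
        page = wiki.page(title)
        if page.exists() and '.' in page.summary:
            return page.summary
    return "I cannot answer this question"

print(fetch_summary("Eiffel Tower"))

Iterating over search results rather than over the raw keyphrases is what makes the lookup more robust: a misspelled or overly specific keyphrase can still land on the right article via the suggestion and ranked search results.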