Spaces: Running on T4
Julien Simon committed
Commit 59b8055 · 1 Parent(s): 0d6c6e0

Update code
app.py
CHANGED
```diff
@@ -1,70 +1,74 @@
-import nltk
-import pickle
-import pandas as pd
 import gradio as gr
+import nltk
 import numpy as np
+import pandas as pd
+from librosa import load, resample
 from sentence_transformers import SentenceTransformer, util
 from transformers import pipeline
-from librosa import load, resample
 
 # Constants
-filename =
+filename = "df10k_SP500_2020.csv.zip"
 
-model_name =
+model_name = "sentence-transformers/msmarco-distilbert-base-v4"
 max_sequence_length = 512
-embeddings_filename =
-asr_model =
+embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
+asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
 
 # Load corpus
 df = pd.read_csv(filename)
 df.drop_duplicates(inplace=True)
-print(f
+print(f"Number of documents: {len(df)}")
 
-nltk.download(
+nltk.download("punkt")
 
 corpus = []
 sentence_count = []
 for _, row in df.iterrows():
     # We're interested in the 'mdna' column: 'Management discussion and analysis'
-    sentences = nltk.tokenize.sent_tokenize(str(row[
+    sentences = nltk.tokenize.sent_tokenize(str(row["mdna"]), language="english")
     sentence_count.append(len(sentences))
-    for _,s in enumerate(sentences):
+    for _, s in enumerate(sentences):
         corpus.append(s)
-print(f
+print(f"Number of sentences: {len(corpus)}")
 
 # Load pre-embedded corpus
-corpus_embeddings = np.load(embeddings_filename)[
-print(f
+corpus_embeddings = np.load(embeddings_filename)["arr_0"]
+print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
 
 # Load embedding model
 model = SentenceTransformer(model_name)
 model.max_seq_length = max_sequence_length
 
 # Load speech to text model
-asr = pipeline(
+asr = pipeline(
+    "automatic-speech-recognition", model=asr_model, feature_extractor=asr_model
+)
+
 
 def find_sentences(query, hits):
     query_embedding = model.encode(query)
     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
     hits = hits[0]
 
-    output = pd.DataFrame(
+    output = pd.DataFrame(
+        columns=["Ticker", "Form type", "Filing date", "Text", "Score"]
+    )
     for hit in hits:
-        corpus_id = hit[
+        corpus_id = hit["corpus_id"]
         # Find source document based on sentence index
         count = 0
         for idx, c in enumerate(sentence_count):
-            count+=c
-            if
+            count += c
+            if corpus_id > count - 1:
                 continue
             else:
                 doc = df.iloc[idx]
-                new_row = {
-
-
-
-
-
+                new_row = {
+                    "Ticker": doc["ticker"],
+                    "Form type": doc["form_type"],
+                    "Filing date": doc["filing_date"],
+                    "Text": corpus[corpus_id][:80],
+                    "Score": "{:.2f}".format(hit["score"]),
                 }
                 output = output.append(new_row, ignore_index=True)
                 break
@@ -72,43 +76,70 @@ def find_sentences(query, hits):
 
 
 def process(input_selection, query, filepath, hits):
-
-
-
-
-
-
-
-
+    if input_selection == "speech":
+        speech, sampling_rate = load(filepath)
+        if sampling_rate != 16000:
+            speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
+        text = asr(speech)["text"]
+    else:
+        text = query
+    return text, find_sentences(text, hits)
+
 
 # Gradio inputs
-buttons
-
-
-
+buttons = gr.Radio(
+    ["text", "speech"], type="value", value="speech", label="Input selection"
+)
+text_query = gr.Textbox(
+    lines=1,
+    label="Text input",
+    value="The company is under investigation by tax authorities for potential fraud.",
+)
+mic = gr.Audio(
+    source="microphone", type="filepath", label="Speech input", optional=True
+)
+slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of hits")
 
 # Gradio outputs
-speech_query = gr.Textbox(type=
-results
-
-
-
+speech_query = gr.Textbox(type="text", label="Query string")
+results = gr.Dataframe(
+    type="pandas",
+    headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
+    label="Query results",
+)
 
 iface = gr.Interface(
-
-
-
-
-
-
-
-
-
-
-
-
-    [
+    theme="huggingface",
+    description="This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80",
+    fn=process,
+    inputs=[buttons, text_query, mic, slider],
+    outputs=[speech_query, results],
+    examples=[
+        [
+            "speech",
+            "Nos ventes internationales ont significativement augmenté.",
+            "sales_16k_fr.wav",
+            3,
+        ],
+        [
+            "speech",
+            "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
+            "energy_16k_fr.wav",
+            3,
+        ],
+        [
+            "speech",
+            "El precio de la energía podría tener un impacto negativo en el futuro.",
+            "energy_24k_es.wav",
+            3,
+        ],
+        [
+            "speech",
+            "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
+            "tax_24k_de.wav",
+            3,
+        ],
+    ],
-    allow_flagging=False
+    allow_flagging=False,
 )
 iface.launch()
```
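The constants above point at a precomputed embeddings file that app.py loads with `np.load(embeddings_filename)["arr_0"]`. The generation script is not part of this commit, so the sketch below is only an assumption: it rebuilds the corpus with the same steps and model as app.py, and it explains the `"arr_0"` key, since `np.savez` stores unnamed arrays under that default name.

```python
# Hypothetical generation script for the precomputed embeddings (not in this
# commit): rebuild the sentence corpus exactly as app.py does, encode it, and
# save it with np.savez, which names unnamed arrays "arr_0".
import nltk
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

nltk.download("punkt")

df = pd.read_csv("df10k_SP500_2020.csv.zip")
df.drop_duplicates(inplace=True)

corpus = []
for _, row in df.iterrows():
    corpus.extend(nltk.tokenize.sent_tokenize(str(row["mdna"]), language="english"))

model = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-v4")
model.max_seq_length = 512

embeddings = model.encode(corpus, show_progress_bar=True)
np.savez("df10k_embeddings_msmarco-distilbert-base-v4.npz", embeddings)
```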
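`find_sentences` also depends on the result shape of `util.semantic_search`: it returns one inner list per query, each entry a dict with `"corpus_id"` and `"score"` keys, hence `hits = hits[0]` for the single query and the `hit["corpus_id"]` lookups. A minimal, self-contained illustration (the two corpus sentences are made up):

```python
# Illustrates the list-of-lists-of-dicts result shape find_sentences relies on.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-v4")
corpus_embeddings = model.encode(
    ["International sales grew significantly.", "Tax authorities opened an inquiry."]
)
query_embedding = model.encode("The company is under tax investigation.")

hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=2)
for hit in hits[0]:  # hits[0]: ranked results for the first (and only) query
    print(hit["corpus_id"], "{:.2f}".format(hit["score"]))
```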
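The least obvious part of `find_sentences` is mapping a flat sentence index back to its source row: `sentence_count[i]` records how many sentences document `i` contributed to `corpus`, so cumulative sums delimit each document's slice. A toy version of that loop, with made-up counts:

```python
# Toy version of the index-mapping loop in find_sentences: walk the cumulative
# sentence counts until corpus_id falls inside document idx's range.
sentence_count = [3, 2, 4]  # hypothetical counts for three documents

def doc_index(corpus_id):
    count = 0
    for idx, c in enumerate(sentence_count):
        count += c
        if corpus_id > count - 1:
            continue  # corpus_id lies past document idx
        return idx  # first document whose range covers corpus_id

# Sentences 0-2 -> doc 0, 3-4 -> doc 1, 5-8 -> doc 2
assert [doc_index(i) for i in range(9)] == [0, 0, 0, 1, 1, 2, 2, 2, 2]
```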
dummy.wav DELETED