IanRonk commited on
Commit
9f6db74
1 Parent(s): 7008ae2

add returning probabilities as well

Browse files
Files changed (2) hide show
  1. app.py +9 -2
  2. functions/model_infer.py +1 -1
app.py CHANGED
@@ -14,9 +14,16 @@ def pipeline(video_url):
14
  video_id = video_url.split("?v=")[-1]
15
  punctuated_text = punctuate(video_id)
16
  sentences = re.split(r"[\.\!\?]\s", punctuated_text)
17
- classification = predict_from_document(sentences)
18
  # return punctuated_text
19
- return [{"start": "12:05", "end": "12:52", "classification": str(classification)}]
 
 
 
 
 
 
 
20
 
21
 
22
  # print(pipeline("VL5M5ZihJK4"))
 
14
  video_id = video_url.split("?v=")[-1]
15
  punctuated_text = punctuate(video_id)
16
  sentences = re.split(r"[\.\!\?]\s", punctuated_text)
17
+ classification, probs = predict_from_document(sentences)
18
  # return punctuated_text
19
+ return [
20
+ {
21
+ "start": "12:05",
22
+ "end": "12:52",
23
+ "classification": str(classification),
24
+ "probabilities": probs,
25
+ }
26
+ ]
27
 
28
 
29
  # print(pipeline("VL5M5ZihJK4"))
functions/model_infer.py CHANGED
@@ -38,4 +38,4 @@ def predict_from_document(sentences):
38
  preprop = preprocess(sentences)
39
  prediction = model.predict(preprop)
40
  output = (prediction.flatten()[: len(sentences)] >= 0.5).astype(int)
41
- return output
 
38
  preprop = preprocess(sentences)
39
  prediction = model.predict(preprop)
40
  output = (prediction.flatten()[: len(sentences)] >= 0.5).astype(int)
41
+ return output, prediction.flatten()[: len(sentences)]