gagan3012 commited on
Commit
515949b
·
1 Parent(s): 824d37a

added params

Browse files
requirements.txt CHANGED
@@ -6,6 +6,7 @@ torch==1.9.0+cu111
6
  dagshub==0.1.6
7
  pandas==1.2.4
8
  rouge_score
 
9
 
10
  # external requirements
11
  click
 
6
  dagshub==0.1.6
7
  pandas==1.2.4
8
  rouge_score
9
+ yaml
10
 
11
  # external requirements
12
  click
src/models/evaluate_model.py CHANGED
@@ -14,7 +14,7 @@ def evaluate_model():
14
 
15
  test_df = pd.load_csv('data/processed/test.csv')
16
  model = Summarization()
17
- model.load_model(model_dir=params['model_dir'])
18
  results = model.evaluate(test_df=test_df, metrics=params['metric'])
19
 
20
  with dagshub.dagshub_logger() as logger:
 
14
 
15
  test_df = pd.load_csv('data/processed/test.csv')
16
  model = Summarization()
17
+ model.load_model(model_type=params['model_type'], model_dir=params['model_dir'])
18
  results = model.evaluate(test_df=test_df, metrics=params['metric'])
19
 
20
  with dagshub.dagshub_logger() as logger:
src/models/predict_model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from .model import Summarization
2
  import pandas as pd
3
 
@@ -6,8 +8,12 @@ def predict_model(text):
6
  """
7
  Predict the summary of the given text.
8
  """
 
 
 
 
9
  model = Summarization()
10
- model.load_model()
11
  pre_summary = model.predict(text)
12
  return pre_summary
13
 
 
1
+ import yaml
2
+
3
  from .model import Summarization
4
  import pandas as pd
5
 
 
8
  """
9
  Predict the summary of the given text.
10
  """
11
+ with open("params.yml") as f:
12
+ params = yaml.safe_load(f)
13
+
14
+
15
  model = Summarization()
16
+ model.load_model(model_type=params['model_type'], model_dir=params['model_dir'])
17
  pre_summary = model.predict(text)
18
  return pre_summary
19