kovacsvi commited on
Commit
f295362
·
1 Parent(s): 08a8f54

updated model selection logic

Browse files
Files changed (1) hide show
  1. interfaces/cap.py +11 -34
interfaces/cap.py CHANGED
@@ -47,43 +47,20 @@ def check_huggingface_path(checkpoint_path: str):
47
  def build_huggingface_path(language: str, domain: str):
48
  language = language.lower()
49
  base_path = "xlm-roberta-large"
50
-
51
- if language == "english" and (domain == "media" or domain == "legislative"):
52
- lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v4"
53
- return lang_domain_path
54
- else:
55
- lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
56
-
57
  lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
58
 
59
- path_map = {
60
- "L": lang_path,
61
- "L-D": lang_domain_path,
62
- "X": lang_domain_path,
63
- }
64
- value = None
65
-
66
- try:
67
- lang_domain_table = pd.read_csv("language_domain_models.csv")
68
- lang_domain_table["language"] = lang_domain_table["language"].str.lower()
69
- lang_domain_table.columns = lang_domain_table.columns.str.lower()
70
- # get the row for the language and them get the value from the domain column
71
- row = lang_domain_table[(lang_domain_table["language"] == language)]
72
- tmp = row.get(domain)
73
- if not tmp.empty:
74
- value = tmp.iloc[0]
75
- except (AttributeError, FileNotFoundError):
76
- value = None
77
-
78
- if language == 'english':
79
- model_path = lang_path
80
- else:
81
- model_path = "poltextlab/xlm-roberta-large-pooled-cap"
82
-
83
- if check_huggingface_path(model_path):
84
- return model_path
85
  else:
86
- return "poltextlab/xlm-roberta-large-pooled-cap"
 
 
87
 
88
 
89
  def predict(text, model_id, tokenizer_id):
 
47
  def build_huggingface_path(language: str, domain: str):
48
  language = language.lower()
49
  base_path = "xlm-roberta-large"
 
 
 
 
 
 
 
50
  lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
51
 
52
+ # some custom mapping
53
+ if language in ["english"]:
54
+ if domain in ["media", "legislative"]:
55
+ return f"poltextlab/{base_path}-{language}-{domain}-cap-v4"
56
+ elif domain in ["social"]:
57
+ return f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
58
+ return lang_path
59
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  else:
61
+ if domain in ["social"]:
62
+ return f"poltextlab/{base_path}-{domain}-cap-v3"
63
+ return "poltextlab/xlm-roberta-large-pooled-cap-v3"
64
 
65
 
66
  def predict(text, model_id, tokenizer_id):