Spaces:
Running
Running
kovacsvi
committed on
Commit
·
f295362
1
Parent(s):
08a8f54
updated model selection logic
Browse files - interfaces/cap.py +11 -34
interfaces/cap.py
CHANGED
@@ -47,43 +47,20 @@ def check_huggingface_path(checkpoint_path: str):
|
|
47 |
def build_huggingface_path(language: str, domain: str):
|
48 |
language = language.lower()
|
49 |
base_path = "xlm-roberta-large"
|
50 |
-
|
51 |
-
if language == "english" and (domain == "media" or domain == "legislative"):
|
52 |
-
lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v4"
|
53 |
-
return lang_domain_path
|
54 |
-
else:
|
55 |
-
lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
|
56 |
-
|
57 |
lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
"
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
lang_domain_table = pd.read_csv("language_domain_models.csv")
|
68 |
-
lang_domain_table["language"] = lang_domain_table["language"].str.lower()
|
69 |
-
lang_domain_table.columns = lang_domain_table.columns.str.lower()
|
70 |
-
# get the row for the language and then get the value from the domain column
|
71 |
-
row = lang_domain_table[(lang_domain_table["language"] == language)]
|
72 |
-
tmp = row.get(domain)
|
73 |
-
if not tmp.empty:
|
74 |
-
value = tmp.iloc[0]
|
75 |
-
except (AttributeError, FileNotFoundError):
|
76 |
-
value = None
|
77 |
-
|
78 |
-
if language == 'english':
|
79 |
-
model_path = lang_path
|
80 |
-
else:
|
81 |
-
model_path = "poltextlab/xlm-roberta-large-pooled-cap"
|
82 |
-
|
83 |
-
if check_huggingface_path(model_path):
|
84 |
-
return model_path
|
85 |
else:
|
86 |
-
|
|
|
|
|
87 |
|
88 |
|
89 |
def predict(text, model_id, tokenizer_id):
|
|
|
47 |
def build_huggingface_path(language: str, domain: str):
|
48 |
language = language.lower()
|
49 |
base_path = "xlm-roberta-large"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
|
51 |
|
52 |
+
# some custom mapping
|
53 |
+
if language in ["english"]:
|
54 |
+
if domain in ["media", "legislative"]:
|
55 |
+
return f"poltextlab/{base_path}-{language}-{domain}-cap-v4"
|
56 |
+
elif domain in ["social"]:
|
57 |
+
return f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
|
58 |
+
return lang_path
|
59 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
else:
|
61 |
+
if domain in ["social"]:
|
62 |
+
return f"poltextlab/{base_path}-{domain}-cap-v3"
|
63 |
+
return "poltextlab/xlm-roberta-large-pooled-cap-v3"
|
64 |
|
65 |
|
66 |
def predict(text, model_id, tokenizer_id):
|