Abhipsha Das committed on

add files

- data/databases/README.md +32 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/eval/__init__.py +0 -0
- src/eval/metrics.py +87 -0
- src/processing/__init__.py +0 -0
- src/processing/__pycache__/__init__.cpython-311.pyc +0 -0
- src/processing/__pycache__/extractions.cpython-311.pyc +0 -0
- src/processing/__pycache__/generate.cpython-311.pyc +0 -0
- src/processing/extractions.py +65 -0
- src/processing/generate.py +226 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- src/utils/__pycache__/utils.cpython-311.pyc +0 -0
- src/utils/utils.py +155 -0
data/databases/README.md
ADDED
@@ -0,0 +1,32 @@
This folder contains all the SQL databases for the different processed data along with their raw data.

The databases are named after the arXiv category and the format of the generated data.

Each file in this folder is a database containing 2 tables:

- **papers**

  The papers data from the `raw` folder that was fed to the model.

  SCHEMA:
  - paper_id TEXT PRIMARY KEY,
  - abstract TEXT,
  - authors TEXT,
  - primary_category TEXT,
  - url TEXT,
  - updated_on TEXT,
  - sentence_count INTEGER

- **predictions**

  The corresponding model generations stored in the `results` folder.

  SCHEMA:
  - id INTEGER PRIMARY KEY AUTOINCREMENT,
  - paper_id TEXT,
  - sentence_index INTEGER,
  - tag_type TEXT,
  - concept TEXT,
  - FOREIGN KEY (paper_id) REFERENCES papers(paper_id)

To query any database, open SQLite in your terminal and specify the database name.
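A quick way to inspect one of these databases from Python with the standard-library sqlite3 module (illustrative only, not part of this commit; the file name cs_CL_json.db is a made-up placeholder for whichever database you want to query):

import sqlite3

# Hypothetical database name; substitute any file from data/databases/.
conn = sqlite3.connect("data/databases/cs_CL_json.db")
rows = conn.execute(
    """
    SELECT p.paper_id, pr.sentence_index, pr.tag_type, pr.concept
    FROM papers AS p
    JOIN predictions AS pr ON pr.paper_id = p.paper_id
    LIMIT 10
    """
).fetchall()
for paper_id, sentence_index, tag_type, concept in rows:
    print(paper_id, sentence_index, tag_type, concept)
conn.close()
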
src/__init__.py
ADDED
File without changes

src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (155 Bytes)

src/eval/__init__.py
ADDED
File without changes

src/eval/metrics.py
ADDED
@@ -0,0 +1,87 @@
from typing import Dict, Tuple


def classify_predictions(gold: dict, pred: dict, union=False) -> Tuple[int, int, int]:
    """
    Returns true positives, false positives, and false negatives for one example.
    If union is True, then disregards the type of the tag and only considers the union of all tags.
    """
    n_tp = 0
    n_fp = 0
    n_fn = 0
    if union:
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn

    for tag in set(gold.keys()).union(pred.keys()):
        gold_phrases = set(gold.get(tag, []))
        pred_phrases = set(pred.get(tag, []))

        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)

    return n_tp, n_fp, n_fn


def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
    metrics = {}
    metrics["avg_pred_response_time_per_sentence"] = (
        round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
    )
    metrics["total_time"] = round(running_time, 4)

    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics

        precision = round(n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else 0, 4)
        recall = round(n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else 0, 4)
        f1 = round(
            (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0
            ),
            4,
        )
        union_precision = round(
            (
                n_tp_union / (n_tp_union + n_fp_union)
                if (n_tp_union + n_fp_union) > 0
                else 0
            ),
            4,
        )
        union_recall = round(
            (
                n_tp_union / (n_tp_union + n_fn_union)
                if (n_tp_union + n_fn_union) > 0
                else 0
            ),
            4,
        )
        union_f1 = round(
            (
                2 * (union_precision * union_recall) / (union_precision + union_recall)
                if (union_precision + union_recall) > 0
                else 0
            ),
            4,
        )

        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )

    return metrics
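For reference, a minimal usage sketch of the two functions above (not part of the commit; the gold/pred dictionaries are made up, and the import assumes the repository root is on PYTHONPATH):

from src.eval.metrics import classify_predictions, compute_metrics

# Hypothetical gold and predicted tag-to-phrases mappings for one sentence.
gold = {"method": ["transformer encoder"], "task": ["question answering"]}
pred = {"method": ["transformer encoder"], "task": ["summarization"]}

n_tp, n_fp, n_fn = classify_predictions(gold, pred)              # per-tag matching
u_tp, u_fp, u_fn = classify_predictions(gold, pred, union=True)  # tag-agnostic matching

metrics = compute_metrics(
    running_time=12.3,
    pred_times=[0.8, 0.7],
    runtype="eval",
    eval_metrics=(n_tp, n_fp, n_fn, u_tp, u_fp, u_fn),
)
print(metrics["precision"], metrics["recall"], metrics["union_f1"])
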
src/processing/__init__.py
ADDED
File without changes

src/processing/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (166 Bytes)

src/processing/__pycache__/extractions.cpython-311.pyc
ADDED
Binary file (4.36 kB)

src/processing/__pycache__/generate.cpython-311.pyc
ADDED
Binary file (9.78 kB)

src/processing/extractions.py
ADDED
@@ -0,0 +1,65 @@
import json
import logging
import re
from bs4 import BeautifulSoup
from collections import defaultdict
from typing import Dict, List


# TODO: review the functions here
def extract_all_tagged_phrases(text: str) -> Dict[str, List[str]]:
    soup = BeautifulSoup(text, "html.parser")
    tagged_phrases = defaultdict(list)

    for tag in soup.find_all(True):
        if tag.name:
            # Clean and process the text
            full_text = " ".join(tag.stripped_strings)
            full_text = re.sub(r"\s+", " ", full_text.strip())
            full_text = re.sub(r'(?<!\\)\\(?!["\\])', r"\\\\", full_text)
            full_text = full_text.replace('"', '\\"')

            if full_text:  # Only add non-empty strings
                tagged_phrases[tag.name].append(full_text)

    # Remove duplicates while preserving order
    return {
        tag: list(dict.fromkeys(phrases)) for tag, phrases in tagged_phrases.items()
    }


def extract_prediction(schema: dict, prediction: str, kind: str = "json") -> dict:
    pred = {}
    if kind == "json":
        json_match = re.search(r"\{[\s\S]+\}", prediction)
        if json_match:
            json_str = json_match.group(0)
            json_str = re.sub(r"(\w+)-\$?\\?(\w+)\$?", r"\1-\2", json_str)
            json_str = json_str.replace('\\"', '"')
            json_str = re.sub(r'}\s*"', '}, "', json_str)
            json_str = re.sub(r']\s*"', '], "', json_str)
            try:
                pred = json.loads(json_str)
            except json.JSONDecodeError as e:
                logging.warning(f"Failed to parse JSON: {json_str}")
                logging.warning(f"Error: {str(e)}")

                try:
                    json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
                    json_str = re.sub(r"(?<![\w'])'|'(?![\w'])", '"', json_str)
                    pred = json.loads(json_str)
                except json.JSONDecodeError:
                    logging.error(
                        f"Failed to parse JSON even after attempted fixes: {json_str}"
                    )
    elif kind == "readable":
        match = re.findall(
            rf'^({"|".join(list(schema.keys()))}): (.+)$',
            prediction,
            flags=re.MULTILINE,
        )
        pred = {tag: values.split(", ") for tag, values in match}
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return pred
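A minimal sketch of how these extraction helpers are meant to be used (not part of the commit; the tagged sentence, schema, and raw model output below are made up):

from src.processing.extractions import extract_all_tagged_phrases, extract_prediction

# Hypothetical tagged sentence in the XML-style markup the model is prompted to produce.
tagged = "We propose <method>a graph neural network</method> for <task>link prediction</task>."
print(extract_all_tagged_phrases(tagged))
# {'method': ['a graph neural network'], 'task': ['link prediction']}

# Hypothetical schema and raw model response; extract_prediction pulls out the JSON part.
schema = {"method": "the proposed approach", "task": "the problem being solved"}
raw_output = 'Here are the tags: {"method": ["graph neural network"], "task": ["link prediction"]}'
print(extract_prediction(schema, raw_output, kind="json"))
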
src/processing/generate.py
ADDED
@@ -0,0 +1,226 @@
import json
import random
import re

# import spacy
import torch

from config import (
    DEFAULT_FEW_SHOT_NUM,
    DEFAULT_FEW_SHOT_SELECTION,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_KIND,
)
from typing import List, Dict, Tuple, Union
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from .extractions import extract_all_tagged_phrases

# nlp = spacy.load("en_core_web_sm")


# TODO: run with constituency tests
# TODO: review instruction and system level prompt (currently they are repetitive)
def get_sentences(text: str) -> List[str]:
    # TODO: spacy splitting results in unequal lengths
    # doc = nlp(text)
    # sentences = [sent.text.strip() for sent in doc.sents]
    # sentences = [s for s in sentences if s]
    # return sentences

    return text.split(". ")


def format_instance(sentence: str, extraction: Union[str, None]) -> str:
    return "".join(
        [
            f"Sentence: {sentence}\n",
            (
                f"Extractions:\n{extraction}\n"
                if extraction is not None
                else "Extractions:\n"
            ),
        ]
    )


def generate_instructions(schema: dict, kind: str = DEFAULT_KIND) -> str:
    instruction_parts = [
        "The following schema is provided to tag the title and abstract of a given scientific paper as shown in the examples:\n"
    ]
    if kind == "json":
        instruction_parts.append(f"{json.dumps(schema, indent=2)}\n\n")
    elif kind == "readable":
        readable_schema = ""
        for tag, description in schema.items():
            readable_schema += f"{tag}: {description}\n"
        instruction_parts.append(f"{readable_schema}\n")
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return "".join(instruction_parts)


def generate_demonstrations(
    examples: List[dict],
    kind: str = DEFAULT_KIND,
    num_examples: int = DEFAULT_FEW_SHOT_NUM,
    selection: str = DEFAULT_FEW_SHOT_SELECTION,
) -> str:
    demonstration_parts = []
    for example in examples:
        sentences = get_sentences(example["abstract"])
        tagged_sentences = get_sentences(example["tagged_abstract"])
        paired_sentences = list(zip(sentences, tagged_sentences, strict=True))

        if selection == "random":
            selected_pairs = random.sample(
                paired_sentences, min(num_examples, len(paired_sentences))
            )
        elif selection == "first":
            selected_pairs = paired_sentences[:num_examples]
        elif selection == "last":
            selected_pairs = paired_sentences[-num_examples:]
        elif selection == "middle":
            start = max(0, (len(paired_sentences) - num_examples) // 2)
            selected_pairs = paired_sentences[start : start + num_examples]
        elif selection == "distributed":
            step = max(1, len(paired_sentences) // num_examples)
            selected_pairs = paired_sentences[::step][:num_examples]
        elif selection == "longest":
            selected_pairs = sorted(
                paired_sentences, key=lambda x: len(x[0]), reverse=True
            )[:num_examples]
        elif selection == "shortest":
            selected_pairs = sorted(paired_sentences, key=lambda x: len(x[0]))[
                :num_examples
            ]
        else:
            raise ValueError(f"Invalid selection method: {selection}")

        for sentence, tagged_sentence in selected_pairs:
            tag_to_phrase = extract_all_tagged_phrases(tagged_sentence)
            if kind == "json":
                extractions = f"{json.dumps(tag_to_phrase, indent=2)}\n"
            elif kind == "readable":
                extractions = "".join(
                    f"{tag}: {', '.join(phrase)}\n"
                    for tag, phrase in tag_to_phrase.items()
                )
            else:
                raise ValueError(f"Invalid kind: {kind}")

            demonstration_parts.append(format_instance(sentence, extractions))

    return "".join(demonstration_parts)


def generate_prefix(instructions: str, demonstrations: str) -> str:
    return f"{instructions}{demonstrations}"


def generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input: str,
    kind: str,
    system_prompt: str = "You are an assistant who tags papers according to given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> str:
    prompt = prefix + input
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {"role": "user", "content": prompt},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        # add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=1200,
        eos_token_id=terminators,
        # num_beams=8,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    response = outputs[0][input_ids.shape[-1] :]
    prediction_response = tokenizer.decode(response, skip_special_tokens=True)

    return prediction_response


def batch_generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input_ids: torch.Tensor,
    kind: str,
    system_prompt: str = "You are an assistant who tags papers according to given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    max_new_tokens: int = 1200,
    batch_size: int = 1,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> List[str]:
    all_predictions = []

    # Prepare system message
    system_message = {"role": "system", "content": system_prompt}

    for i in range(0, input_ids.size(0), batch_size):
        batch_input_ids = input_ids[i : i + batch_size]

        batch_messages = [
            [
                system_message,
                {
                    "role": "user",
                    "content": prefix + tokenizer.decode(ids, skip_special_tokens=True),
                },
            ]
            for ids in batch_input_ids
        ]

        batch_input_ids = tokenizer.apply_chat_template(
            batch_messages, return_tensors="pt", padding=True, truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=batch_input_ids.ne(tokenizer.pad_token_id),
            )

        for output in outputs:
            response = output[batch_input_ids.size(1) :]
            prediction_response = tokenizer.decode(response, skip_special_tokens=True)
            all_predictions.append(prediction_response)

        torch.cuda.empty_cache()

    return all_predictions
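A minimal sketch of how the prompt prefix is assembled from a schema and one worked example (not part of the commit; the schema and example paper are made up, and the import assumes the project's config module is importable from the working directory):

from src.processing.generate import (
    format_instance,
    generate_demonstrations,
    generate_instructions,
    generate_prefix,
)

# Hypothetical tagging schema and one demonstration paper whose tagged_abstract
# mirrors the abstract sentence by sentence.
schema = {"method": "the proposed approach", "task": "the problem being solved"}
examples = [
    {
        "abstract": "We study parsing. We propose a chart parser.",
        "tagged_abstract": "We study <task>parsing</task>. We propose <method>a chart parser</method>.",
    }
]

instructions = generate_instructions(schema, kind="json")
demonstrations = generate_demonstrations(examples, kind="json", num_examples=2, selection="first")
prefix = generate_prefix(instructions, demonstrations)

# The final prompt for a new sentence is the prefix plus an instance with empty extractions.
print(prefix + format_instance("A new sentence to tag.", None))
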
src/utils/__init__.py
ADDED
File without changes

src/utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (161 Bytes)

src/utils/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (8.47 kB)

src/utils/utils.py
ADDED
@@ -0,0 +1,155 @@
import json
import logging
import os
import torch

from config import DEFAULT_RES_DIR as RES_DIR

from accelerate import (
    infer_auto_device_map,
    init_empty_weights,
    Accelerator,
    load_checkpoint_and_dispatch,
)
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig


def save_results(
    out_dir_path,
    all_inputs,
    gold_tags,
    predicted_responses,
    predicted_tags,
    metrics,
    runtype,
    append=False,
):
    mode = "a" if append else "w"

    with open(
        os.path.join(RES_DIR, out_dir_path, "prompts.txt"), mode, encoding="utf-8"
    ) as f:
        for input, gold_tag, pred_response, pred_tag in zip(
            all_inputs, gold_tags, predicted_responses, predicted_tags
        ):
            f.write(f"{input}\n")
            f.write(f"True Tag: {gold_tag}\n")
            f.write(f"Predicted Response: {pred_response}\n")
            f.write(f"Predicted Tag: {pred_tag}\n")
            f.write("#" * 50 + "\n")

    with open(
        os.path.join(RES_DIR, out_dir_path, "predicted_responses.txt"),
        mode,
        encoding="utf-8",
    ) as f:
        for response in predicted_responses:
            f.write(f"{response}\n")
            f.write("#" * 50 + "\n")

    if append:
        with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "r+") as f:
            data = json.load(f)
            data["predicted_tags"].extend(predicted_tags)
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()
    else:
        with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "w") as f:
            json.dump({"predicted_tags": predicted_tags}, f, indent=4)

    if runtype == "eval":
        if append:
            with open(
                os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "r+"
            ) as f:
                data = json.load(f)
                data["gold_tags"].extend(gold_tags)
                f.seek(0)
                json.dump(data, f, indent=4)
                f.truncate()
        else:
            with open(
                os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "w"
            ) as f:
                json.dump({"gold_tags": gold_tags}, f, indent=4)

        with open(os.path.join(RES_DIR, out_dir_path, "metrics.json"), "w") as f:
            json.dump({"metrics": metrics, "prompt_file": "prompts.txt"}, f, indent=4)

    logging.info(f"Results saved in: {os.path.join(RES_DIR, out_dir_path)}")


def save_best_config(metrics, config):
    best_config_path = os.path.join(RES_DIR, "best_config.json")
    if os.path.exists(best_config_path):
        with open(best_config_path, "r") as f:
            best_config = json.load(f)
        if metrics["precision"] > best_config["metrics"]["precision"]:
            best_config = {"metrics": metrics, "config": config}
    else:
        best_config = {"metrics": metrics, "config": config}

    with open(best_config_path, "w") as f:
        json.dump(best_config, f, indent=4)


def load_sweep_config(config_path="sweep_config.json"):
    with open(config_path, "r") as f:
        return json.load(f)


# def load_model_and_tokenizer(model_id: str):
#     accelerator = Accelerator()

#     tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
#     # device_map = infer_auto_device_map(model, max_memory=max_memory)

#     if tokenizer.pad_token_id is None:
#         tokenizer.pad_token_id = tokenizer.eos_token_id

#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#         token=os.getenv("HF_TOKEN"),
#     )

#     model, tokenizer = accelerator.prepare(model, tokenizer)

#     return model, tokenizer


def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.memory.reset_max_memory_allocated()
        torch.cuda.memory.reset_max_memory_cached()


def load_model_and_tokenizer(model_id):
    # Set up memory-saving options
    torch.cuda.empty_cache()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, padding_side="left", use_auth_token=os.getenv("HF_TOKEN")
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Load configuration
    config = AutoConfig.from_pretrained(model_id, use_auth_token=os.getenv("HF_TOKEN"))

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        use_auth_token=os.getenv("HF_TOKEN"),
        device_map="auto",
    )

    return model, tokenizer
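A minimal end-to-end sketch combining the loader above with generate_prediction (not part of the commit; the model id is a placeholder, the prefix stands in for the real instructions plus demonstrations, and HF_TOKEN must be set for gated checkpoints):

from src.processing.generate import format_instance, generate_prediction
from src.utils.utils import clear_cuda_cache, load_model_and_tokenizer

# Placeholder model id; any Hub chat model with an <|eot_id|> token (e.g. Llama 3)
# matches the terminator handling in generate_prediction.
model, tokenizer = load_model_and_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct")

# Stand-in for the instructions + demonstrations prefix built in generate.py.
prefix = "Tag the sentence with the schema shown in the examples.\n"

response = generate_prediction(
    model,
    tokenizer,
    prefix=prefix,
    input=format_instance("We propose a chart parser for dependency parsing.", None),
    kind="json",
)
print(response)
clear_cuda_cache()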