Abhipsha Das committed on
Commit a2b5ed5 · unverified · 1 Parent(s): 8fbb714
data/databases/README.md ADDED
@@ -0,0 +1,32 @@
+ - This folder contains the SQL databases for the different processed datasets, along with their raw data.
+
+ - The databases are named after the arXiv category and the format of the generated data.
+
+ Each file in this folder is a database containing two tables:
+ - **papers**
+
+   The papers data from the `raw` folder that was fed to the model.
+
+   SCHEMA:
+   - paper_id TEXT PRIMARY KEY,
+   - abstract TEXT,
+   - authors TEXT,
+   - primary_category TEXT,
+   - url TEXT,
+   - updated_on TEXT,
+   - sentence_count INTEGER
+
+ - **predictions**
+
+   The corresponding model generations stored in the `results` folder.
+
+   SCHEMA:
+   - id INTEGER PRIMARY KEY AUTOINCREMENT,
+   - paper_id TEXT,
+   - sentence_index INTEGER,
+   - tag_type TEXT,
+   - concept TEXT,
+   - FOREIGN KEY (paper_id) REFERENCES papers(paper_id)
+
+
+ To query any database, open SQLite in your terminal and specify the database name.
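
For a non-interactive query, the same lookup works from Python's built-in sqlite3 module; a minimal sketch, assuming a database file named `cs_CL_json.db` (a hypothetical instance of the category/format naming scheme):

import sqlite3

# Hypothetical filename following the <arXiv-category>_<format>.db convention.
conn = sqlite3.connect("cs_CL_json.db")
cur = conn.cursor()

# Join the two tables to list the concepts extracted from each paper.
cur.execute(
    """
    SELECT p.paper_id, p.primary_category, pr.tag_type, pr.concept
    FROM papers AS p
    JOIN predictions AS pr ON pr.paper_id = p.paper_id
    LIMIT 10
    """
)
for row in cur.fetchall():
    print(row)

conn.close()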
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (155 Bytes).
 
src/eval/__init__.py ADDED
File without changes
src/eval/metrics.py ADDED
@@ -0,0 +1,87 @@
+ from typing import Tuple
+
+
+ def classify_predictions(gold: dict, pred: dict, union=False) -> Tuple[int, int, int]:
+     """
+     Returns true positives, false positives, and false negatives for one example.
+     If union is True, disregards the tag types and scores the union of all tagged phrases.
+     """
+     n_tp = 0
+     n_fp = 0
+     n_fn = 0
+     if union:
+         gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
+         pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
+         n_tp = len(gold_phrases & pred_phrases)
+         n_fp = len(pred_phrases - gold_phrases)
+         n_fn = len(gold_phrases - pred_phrases)
+         return n_tp, n_fp, n_fn
+
+     for tag in set(gold.keys()).union(pred.keys()):
+         gold_phrases = set(gold.get(tag, []))
+         pred_phrases = set(pred.get(tag, []))
+
+         n_tp += len(gold_phrases & pred_phrases)
+         n_fp += len(pred_phrases - gold_phrases)
+         n_fn += len(gold_phrases - pred_phrases)
+
+     return n_tp, n_fp, n_fn
+
+
+ def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
+     metrics = {}
+     metrics["avg_pred_response_time_per_sentence"] = (
+         round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
+     )
+     metrics["total_time"] = round(running_time, 4)
+
+     if runtype == "eval" and eval_metrics is not None:
+         n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics
+
+         precision = round(n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else 0, 4)
+         recall = round(n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else 0, 4)
+         f1 = round(
+             (
+                 2 * (precision * recall) / (precision + recall)
+                 if (precision + recall) > 0
+                 else 0
+             ),
+             4,
+         )
+         union_precision = round(
+             (
+                 n_tp_union / (n_tp_union + n_fp_union)
+                 if (n_tp_union + n_fp_union) > 0
+                 else 0
+             ),
+             4,
+         )
+         union_recall = round(
+             (
+                 n_tp_union / (n_tp_union + n_fn_union)
+                 if (n_tp_union + n_fn_union) > 0
+                 else 0
+             ),
+             4,
+         )
+         union_f1 = round(
+             (
+                 2 * (union_precision * union_recall) / (union_precision + union_recall)
+                 if (union_precision + union_recall) > 0
+                 else 0
+             ),
+             4,
+         )
+
+         metrics.update(
+             {
+                 "precision": precision,
+                 "recall": recall,
+                 "f1": f1,
+                 "union_precision": union_precision,
+                 "union_recall": union_recall,
+                 "union_f1": union_f1,
+             }
+         )
+
+     return metrics
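
A minimal sketch of how these two functions compose, assuming the module is importable as `src.eval.metrics` and using made-up gold/predicted tags and timings:

from src.eval.metrics import classify_predictions, compute_metrics

# Made-up gold and predicted tag -> phrases dictionaries for one sentence.
gold = {"method": ["transformer"], "task": ["tagging"]}
pred = {"method": ["transformer"], "task": ["parsing"]}

n_tp, n_fp, n_fn = classify_predictions(gold, pred)              # per-tag counts
u_tp, u_fp, u_fn = classify_predictions(gold, pred, union=True)  # type-agnostic counts

metrics = compute_metrics(
    running_time=12.3,        # illustrative timings
    pred_times=[0.4, 0.5],
    runtype="eval",
    eval_metrics=(n_tp, n_fp, n_fn, u_tp, u_fp, u_fn),
)
print(metrics["f1"], metrics["union_f1"])  # 0.5 0.5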
src/processing/__init__.py ADDED
File without changes
src/processing/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (166 Bytes).
 
src/processing/__pycache__/extractions.cpython-311.pyc ADDED
Binary file (4.36 kB).
 
src/processing/__pycache__/generate.cpython-311.pyc ADDED
Binary file (9.78 kB).
 
src/processing/extractions.py ADDED
@@ -0,0 +1,65 @@
+ import json
+ import logging
+ import re
+ from bs4 import BeautifulSoup
+ from collections import defaultdict
+ from typing import Dict, List
+
+
+ # TODO: review the functions here
+ def extract_all_tagged_phrases(text: str) -> Dict[str, List[str]]:
+     soup = BeautifulSoup(text, "html.parser")
+     tagged_phrases = defaultdict(list)
+
+     for tag in soup.find_all(True):
+         if tag.name:
+             # Clean and normalize the tagged text
+             full_text = " ".join(tag.stripped_strings)
+             full_text = re.sub(r"\s+", " ", full_text.strip())
+             full_text = re.sub(r'(?<!\\)\\(?!["\\])', r"\\\\", full_text)
+             full_text = full_text.replace('"', '\\"')
+
+             if full_text:  # Only add non-empty strings
+                 tagged_phrases[tag.name].append(full_text)
+
+     # Remove duplicates while preserving order
+     return {
+         tag: list(dict.fromkeys(phrases)) for tag, phrases in tagged_phrases.items()
+     }
+
+
+ def extract_prediction(schema: dict, prediction: str, kind: str = "json") -> dict:
+     pred = {}
+     if kind == "json":
+         json_match = re.search(r"\{[\s\S]+\}", prediction)
+         if json_match:
+             json_str = json_match.group(0)
+             json_str = re.sub(r"(\w+)-\$?\\?(\w+)\$?", r"\1-\2", json_str)
+             json_str = json_str.replace('\\"', '"')
+             json_str = re.sub(r'}\s*"', '}, "', json_str)
+             json_str = re.sub(r']\s*"', '], "', json_str)
+             try:
+                 pred = json.loads(json_str)
+             except json.JSONDecodeError as e:
+                 logging.warning(f"Failed to parse JSON: {json_str}")
+                 logging.warning(f"Error: {str(e)}")
+
+                 try:
+                     json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
+                     json_str = re.sub(r"(?<![\w'])'|'(?![\w'])", '"', json_str)
+                     pred = json.loads(json_str)
+                 except json.JSONDecodeError:
+                     logging.error(
+                         f"Failed to parse JSON even after attempted fixes: {json_str}"
+                     )
+     elif kind == "readable":
+         match = re.findall(
+             rf'^({"|".join(list(schema.keys()))}): (.+)$',
+             prediction,
+             flags=re.MULTILINE,
+         )
+         pred = {tag: values.split(", ") for tag, values in match}
+     else:
+         raise ValueError(f"Invalid kind: {kind}")
+
+     return pred
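
A quick sketch of both extractors on a toy input (tag names and sentences are made up; assumes the module is importable as `src.processing.extractions`):

from src.processing.extractions import extract_all_tagged_phrases, extract_prediction

# Recover tag -> phrases from an inline-tagged sentence.
tagged = "We propose <method>a transformer model</method> for <task>sequence tagging</task>."
print(extract_all_tagged_phrases(tagged))
# {'method': ['a transformer model'], 'task': ['sequence tagging']}

# Parse a model response that embeds a JSON object.
schema = {"method": "an approach or model", "task": "the problem being solved"}
response = 'Extractions:\n{"method": ["a transformer model"]}'
print(extract_prediction(schema, response, kind="json"))
# {'method': ['a transformer model']}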
src/processing/generate.py ADDED
@@ -0,0 +1,226 @@
+ import json
+ import random
+ import re
+
+ # import spacy
+ import torch
+
+ from config import (
+     DEFAULT_FEW_SHOT_NUM,
+     DEFAULT_FEW_SHOT_SELECTION,
+     DEFAULT_TEMPERATURE,
+     DEFAULT_TOP_P,
+     DEFAULT_KIND,
+ )
+ from typing import List, Dict, Tuple, Union
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+ from .extractions import extract_all_tagged_phrases
+
+ # nlp = spacy.load("en_core_web_sm")
+
+
+ # TODO: run with constituency tests
+ # TODO: review instruction and system level prompt (currently they are repetitive)
+ def get_sentences(text: str) -> List[str]:
+     # TODO: spacy splitting results in unequal lengths
+     # doc = nlp(text)
+     # sentences = [sent.text.strip() for sent in doc.sents]
+     # sentences = [s for s in sentences if s]
+     # return sentences
+
+     return text.split(". ")
+
+
+ def format_instance(sentence: str, extraction: Union[str, None]) -> str:
+     return "".join(
+         [
+             f"Sentence: {sentence}\n",
+             (
+                 f"Extractions:\n{extraction}\n"
+                 if extraction is not None
+                 else "Extractions:\n"
+             ),
+         ]
+     )
+
+
+ def generate_instructions(schema: dict, kind: str = DEFAULT_KIND) -> str:
+     instruction_parts = [
+         "The following schema is provided to tag the title and abstract of a given scientific paper as shown in the examples:\n"
+     ]
+     if kind == "json":
+         instruction_parts.append(f"{json.dumps(schema, indent=2)}\n\n")
+     elif kind == "readable":
+         readable_schema = ""
+         for tag, description in schema.items():
+             readable_schema += f"{tag}: {description}\n"
+         instruction_parts.append(f"{readable_schema}\n")
+     else:
+         raise ValueError(f"Invalid kind: {kind}")
+
+     return "".join(instruction_parts)
+
+
+ def generate_demonstrations(
+     examples: List[dict],
+     kind: str = DEFAULT_KIND,
+     num_examples: int = DEFAULT_FEW_SHOT_NUM,
+     selection: str = DEFAULT_FEW_SHOT_SELECTION,
+ ) -> str:
+     demonstration_parts = []
+     for example in examples:
+         sentences = get_sentences(example["abstract"])
+         tagged_sentences = get_sentences(example["tagged_abstract"])
+         paired_sentences = list(zip(sentences, tagged_sentences, strict=True))
+
+         if selection == "random":
+             selected_pairs = random.sample(
+                 paired_sentences, min(num_examples, len(paired_sentences))
+             )
+         elif selection == "first":
+             selected_pairs = paired_sentences[:num_examples]
+         elif selection == "last":
+             selected_pairs = paired_sentences[-num_examples:]
+         elif selection == "middle":
+             start = max(0, (len(paired_sentences) - num_examples) // 2)
+             selected_pairs = paired_sentences[start : start + num_examples]
+         elif selection == "distributed":
+             step = max(1, len(paired_sentences) // num_examples)
+             selected_pairs = paired_sentences[::step][:num_examples]
+         elif selection == "longest":
+             selected_pairs = sorted(
+                 paired_sentences, key=lambda x: len(x[0]), reverse=True
+             )[:num_examples]
+         elif selection == "shortest":
+             selected_pairs = sorted(paired_sentences, key=lambda x: len(x[0]))[
+                 :num_examples
+             ]
+         else:
+             raise ValueError(f"Invalid selection method: {selection}")
+
+         for sentence, tagged_sentence in selected_pairs:
+             tag_to_phrase = extract_all_tagged_phrases(tagged_sentence)
+             if kind == "json":
+                 extractions = f"{json.dumps(tag_to_phrase, indent=2)}\n"
+             elif kind == "readable":
+                 extractions = "".join(
+                     f"{tag}: {', '.join(phrases)}\n"
+                     for tag, phrases in tag_to_phrase.items()
+                 )
+             else:
+                 raise ValueError(f"Invalid kind: {kind}")
+
+             demonstration_parts.append(format_instance(sentence, extractions))
+
+     return "".join(demonstration_parts)
+
+
+ def generate_prefix(instructions: str, demonstrations: str) -> str:
+     return instructions + demonstrations
+
+
+ def generate_prediction(
+     model,
+     tokenizer,
+     prefix: str,
+     input: str,
+     kind: str,
+     system_prompt: str = "You are an assistant who tags papers according to given schema and "
+     "only returns the tagged phrases in the format as provided in the examples "
+     "without repeating anything else.",
+     temperature: float = DEFAULT_TEMPERATURE,
+     top_p: float = DEFAULT_TOP_P,
+ ) -> str:
+     prompt = prefix + input
+     messages = [
+         {
+             "role": "system",
+             "content": system_prompt,
+         },
+         {"role": "user", "content": prompt},
+     ]
+
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         # add_generation_prompt=True,
+         return_tensors="pt",
+     ).to(model.device)
+
+     terminators = [
+         tokenizer.eos_token_id,
+         tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+     ]
+
+     outputs = model.generate(
+         input_ids,
+         max_new_tokens=1200,
+         eos_token_id=terminators,
+         # num_beams=8,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     response = outputs[0][input_ids.shape[-1] :]
+     prediction_response = tokenizer.decode(response, skip_special_tokens=True)
+
+     return prediction_response
+
+
+ def batch_generate_prediction(
+     model,
+     tokenizer,
+     prefix: str,
+     input_ids: torch.Tensor,
+     kind: str,
+     system_prompt: str = "You are an assistant who tags papers according to given schema and "
+     "only returns the tagged phrases in the format as provided in the examples "
+     "without repeating anything else.",
+     temperature: float = DEFAULT_TEMPERATURE,
+     top_p: float = DEFAULT_TOP_P,
+     max_new_tokens: int = 1200,
+     batch_size: int = 1,
+     device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+ ) -> List[str]:
+     all_predictions = []
+
+     # Prepare system message
+     system_message = {"role": "system", "content": system_prompt}
+
+     for i in range(0, input_ids.size(0), batch_size):
+         batch_input_ids = input_ids[i : i + batch_size]
+
+         batch_messages = [
+             [
+                 system_message,
+                 {
+                     "role": "user",
+                     "content": prefix + tokenizer.decode(ids, skip_special_tokens=True),
+                 },
+             ]
+             for ids in batch_input_ids
+         ]
+
+         batch_input_ids = tokenizer.apply_chat_template(
+             batch_messages, return_tensors="pt", padding=True, truncation=True
+         ).to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 batch_input_ids,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 pad_token_id=tokenizer.pad_token_id,
+                 attention_mask=batch_input_ids.ne(tokenizer.pad_token_id),
+             )
+
+         for output in outputs:
+             response = output[batch_input_ids.size(1) :]
+             prediction_response = tokenizer.decode(response, skip_special_tokens=True)
+             all_predictions.append(prediction_response)
+
+         torch.cuda.empty_cache()
+
+     return all_predictions
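
A rough end-to-end sketch of how the prompt is assembled from these pieces (the schema and example abstract are made up, and the project's `config` module must be importable for the defaults):

from src.processing.generate import (
    format_instance,
    generate_demonstrations,
    generate_instructions,
    generate_prefix,
)

schema = {"method": "an approach or model", "task": "the problem being solved"}
examples = [
    {
        "abstract": "We propose a transformer model. It solves sequence tagging.",
        "tagged_abstract": "We propose <method>a transformer model</method>. "
        "It solves <task>sequence tagging</task>.",
    }
]

prefix = generate_prefix(
    generate_instructions(schema, kind="json"),
    generate_demonstrations(examples, kind="json", num_examples=2, selection="first"),
)
# The sentence to tag goes last, as an instance with empty extractions.
prompt = prefix + format_instance("Our method beats the baseline.", None)
print(prompt)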
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes).
 
src/utils/__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.47 kB).
 
src/utils/utils.py ADDED
@@ -0,0 +1,155 @@
+ import json
+ import logging
+ import os
+ import torch
+
+ from config import DEFAULT_RES_DIR as RES_DIR
+
+ from accelerate import (
+     infer_auto_device_map,
+     init_empty_weights,
+     Accelerator,
+     load_checkpoint_and_dispatch,
+ )
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+
+ def save_results(
+     out_dir_path,
+     all_inputs,
+     gold_tags,
+     predicted_responses,
+     predicted_tags,
+     metrics,
+     runtype,
+     append=False,
+ ):
+     mode = "a" if append else "w"
+
+     with open(
+         os.path.join(RES_DIR, out_dir_path, "prompts.txt"), mode, encoding="utf-8"
+     ) as f:
+         for input, gold_tag, pred_response, pred_tag in zip(
+             all_inputs, gold_tags, predicted_responses, predicted_tags
+         ):
+             f.write(f"{input}\n")
+             f.write(f"True Tag: {gold_tag}\n")
+             f.write(f"Predicted Response: {pred_response}\n")
+             f.write(f"Predicted Tag: {pred_tag}\n")
+             f.write("#" * 50 + "\n")
+
+     with open(
+         os.path.join(RES_DIR, out_dir_path, "predicted_responses.txt"),
+         mode,
+         encoding="utf-8",
+     ) as f:
+         for response in predicted_responses:
+             f.write(f"{response}\n")
+             f.write("#" * 50 + "\n")
+
+     if append:
+         with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "r+") as f:
+             data = json.load(f)
+             data["predicted_tags"].extend(predicted_tags)
+             f.seek(0)
+             json.dump(data, f, indent=4)
+             f.truncate()
+     else:
+         with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "w") as f:
+             json.dump({"predicted_tags": predicted_tags}, f, indent=4)
+
+     if runtype == "eval":
+         if append:
+             with open(
+                 os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "r+"
+             ) as f:
+                 data = json.load(f)
+                 data["gold_tags"].extend(gold_tags)  # the full list, not the loop variable
+                 f.seek(0)
+                 json.dump(data, f, indent=4)
+                 f.truncate()
+         else:
+             with open(
+                 os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "w"
+             ) as f:
+                 json.dump({"gold_tags": gold_tags}, f, indent=4)
+
+     with open(os.path.join(RES_DIR, out_dir_path, "metrics.json"), "w") as f:
+         json.dump({"metrics": metrics, "prompt_file": "prompts.txt"}, f, indent=4)
+
+     logging.info(f"Results saved in: {os.path.join(RES_DIR, out_dir_path)}")
+
+
+ def save_best_config(metrics, config):
+     best_config_path = os.path.join(RES_DIR, "best_config.json")
+     if os.path.exists(best_config_path):
+         with open(best_config_path, "r") as f:
+             best_config = json.load(f)
+         if metrics["precision"] > best_config["metrics"]["precision"]:
+             best_config = {"metrics": metrics, "config": config}
+     else:
+         best_config = {"metrics": metrics, "config": config}
+
+     with open(best_config_path, "w") as f:
+         json.dump(best_config, f, indent=4)
+
+
+ def load_sweep_config(config_path="sweep_config.json"):
+     with open(config_path, "r") as f:
+         return json.load(f)
+
+
+ # def load_model_and_tokenizer(model_id: str):
+ #     accelerator = Accelerator()
+
+ #     tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+ #     # device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+ #     if tokenizer.pad_token_id is None:
+ #         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ #     model = AutoModelForCausalLM.from_pretrained(
+ #         model_id,
+ #         torch_dtype=torch.bfloat16,
+ #         device_map="auto",
+ #         token=os.getenv("HF_TOKEN"),
+ #     )
+
+ #     model, tokenizer = accelerator.prepare(model, tokenizer)
+
+ #     return model, tokenizer
+
+
+ def clear_cuda_cache():
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.memory.reset_max_memory_allocated()
+         torch.cuda.memory.reset_max_memory_cached()
+
+
+ def load_model_and_tokenizer(model_id):
+     # Set up memory-saving options
+     torch.cuda.empty_cache()
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+     # Initialize tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_id, padding_side="left", use_auth_token=os.getenv("HF_TOKEN")
+     )
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     # Load configuration
+     config = AutoConfig.from_pretrained(model_id, use_auth_token=os.getenv("HF_TOKEN"))
+
+     # Load model
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         config=config,
+         torch_dtype=torch.float16,
+         use_auth_token=os.getenv("HF_TOKEN"),
+         device_map="auto",
+     )
+
+     return model, tokenizer
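
And a sketch of wiring the loader into a single prediction call (the model id is only an example; `generate_prediction` assumes a Llama-3-style chat template with an `<|eot_id|>` terminator, and `prefix` would be built as in the `generate.py` sketch above):

from src.processing.generate import generate_prediction
from src.utils.utils import clear_cuda_cache, load_model_and_tokenizer

model, tokenizer = load_model_and_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct")

prefix = "..."  # instructions + demonstrations, assembled as in generate.py
prediction = generate_prediction(
    model,
    tokenizer,
    prefix=prefix,
    input="Sentence: Our method beats the baseline.\nExtractions:\n",
    kind="json",
)
print(prediction)
clear_cuda_cache()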