TangSan003 committed on
Commit
cc370e5
·
verified ·
1 Parent(s): 32c63c8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -44
inference.py CHANGED
@@ -6,7 +6,10 @@ from safetensors.torch import load_file
6
  from datasets import load_dataset
7
  from pprintpp import pprint
8
 
 
9
 
 
 
10
 
11
  class InferenceModel:
12
  """
@@ -14,6 +17,10 @@ class InferenceModel:
14
  """
15
 
16
  def __init__(self, path_to_weights, huggingface_model=True):
 
 
 
 
17
  ### Init Config with either Huggingface Backbone or our own ###
18
  self.config = RobertaConfig(pretrained_backbone="pretrained_huggingface" if huggingface_model else "random")
19
 
@@ -57,48 +64,4 @@ class InferenceModel:
57
 
58
  return prediction
59
 
60
-
61
if __name__ == "__main__":
    # Smoke-test the inference pipeline on a single CoQA validation sample:
    # run the model on one (question, context) pair, then map the predicted
    # token span back to a character span in the original context.

    dataset = load_dataset("stanfordnlp/coqa")

    data = dataset["validation"][2]
    # data = dataset["train"][0]
    # print("answer:", data["answers"])
    ### Sample Text ###
    context = data["story"]
    print("context:", context)
    question = data["questions"][4]

    # NOTE(review): this tokenization must match whatever inference_model()
    # does internally, otherwise offset_mapping will not line up with the
    # predicted token indices — confirm against InferenceModel.
    tokenizer = RobertaTokenizerFast.from_pretrained("deepset/roberta-base-squad2")

    encoded = tokenizer(
        question,
        context,
        max_length=512,
        truncation="only_second",      # truncate the context, never the question
        padding="max_length",
        return_offsets_mapping=True,   # needed to recover character offsets below
        return_tensors="pt"
    )
    # Per-token (start_char, end_char) character offsets; convert to a list of pairs.
    offset_mapping = encoded["offset_mapping"][0].tolist()

    ### Inference Model ###
    path_to_weights = "model/RoBERTa/save_model/model.safetensors"
    inferencer = InferenceModel(path_to_weights=path_to_weights, huggingface_model=True)
    prediction = inferencer.inference_model(question, context)
    print("\n----------------------------------")
    print("results:", prediction)

    # Map the predicted token span back to character positions in `context`.
    start_token_idx = prediction["start_token_idx"]
    end_token_idx = prediction["end_token_idx"]

    start_char = offset_mapping[start_token_idx][0]
    end_char = offset_mapping[end_token_idx][1]

    print("Question:", question)
    print("Recovered answer:", context[start_char:end_char])
104
  # test model
 
6
  from datasets import load_dataset
7
  from pprintpp import pprint
8
 
9
+ import os
10
 
11
+ # Set the HF_HOME environment variable (Hugging Face cache location)
12
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
13
 
14
  class InferenceModel:
15
  """
 
17
  """
18
 
19
  def __init__(self, path_to_weights, huggingface_model=True):
20
+
21
+ self.config = {
22
+ "hf_model_name": "deepset/roberta-base-squad2" # Example: your model name
23
+ }
24
  ### Init Config with either Huggingface Backbone or our own ###
25
  self.config = RobertaConfig(pretrained_backbone="pretrained_huggingface" if huggingface_model else "random")
26
 
 
64
 
65
  return prediction
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # test model