PereLluis13 commited on
Commit
4736d5d
·
1 Parent(s): d8d3783

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +51 -39
README.md CHANGED
@@ -20,11 +20,11 @@ language:
20
  - zh
21
  widget:
22
  - text: >-
23
- The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea,
24
- guitarist Hillel Slovak and drummer Jack Irons.
25
  parameters:
26
  decoder_start_token_id: 250058
27
- src_lang: "en_XX"
 
28
  tags:
29
  - seq2seq
30
  - relation-extraction
@@ -57,31 +57,36 @@ Be aware that the inference widget at the right does not output special tokens,
57
  ```python
58
  from transformers import pipeline
59
 
60
- triplet_extractor = pipeline('text2text-generation', model='Babelscape/mrebel-large', tokenizer='Babelscape/mrebel-large')
61
  # We need to use the tokenizer manually since we need special tokens.
62
- extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor("The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea, guitarist Hillel Slovak and drummer Jack Irons.", return_tensors=True, return_text=False)[0]["generated_token_ids"]])
63
  print(extracted_text[0])
64
  # Function to parse the generated text and extract the triplets
65
- def extract_triplets(text):
66
  triplets = []
67
- relation, subject, relation, object_ = '', '', '', ''
68
  text = text.strip()
69
  current = 'x'
70
- for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
71
- if token == "<triplet>":
 
 
72
  current = 't'
73
  if relation != '':
74
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
75
  relation = ''
76
  subject = ''
77
- elif token == "<subj>":
78
- current = 's'
79
- if relation != '':
80
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
81
- object_ = ''
82
- elif token == "<obj>":
83
- current = 'o'
84
- relation = ''
 
 
 
85
  else:
86
  if current == 't':
87
  subject += ' ' + token
@@ -89,10 +94,10 @@ def extract_triplets(text):
89
  object_ += ' ' + token
90
  elif current == 'o':
91
  relation += ' ' + token
92
- if subject != '' and relation != '' and object_ != '':
93
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
94
  return triplets
95
- extracted_triplets = extract_triplets(extracted_text[0])
96
  print(extracted_triplets)
97
  ```
98
 
@@ -101,26 +106,31 @@ print(extracted_triplets)
101
  ```python
102
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
103
 
104
- def extract_triplets(text):
105
  triplets = []
106
- relation, subject, relation, object_ = '', '', '', ''
107
  text = text.strip()
108
  current = 'x'
109
- for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
110
- if token == "<triplet>":
 
 
111
  current = 't'
112
  if relation != '':
113
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
114
  relation = ''
115
  subject = ''
116
- elif token == "<subj>":
117
- current = 's'
118
- if relation != '':
119
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
120
- object_ = ''
121
- elif token == "<obj>":
122
- current = 'o'
123
- relation = ''
 
 
 
124
  else:
125
  if current == 't':
126
  subject += ' ' + token
@@ -128,18 +138,19 @@ def extract_triplets(text):
128
  object_ += ' ' + token
129
  elif current == 'o':
130
  relation += ' ' + token
131
- if subject != '' and relation != '' and object_ != '':
132
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
133
  return triplets
134
 
135
  # Load model and tokenizer
136
- tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
137
- model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
138
  gen_kwargs = {
139
  "max_length": 256,
140
  "length_penalty": 0,
141
  "num_beams": 3,
142
  "num_return_sequences": 3,
 
143
  }
144
 
145
  # Text to extract triplets from
@@ -152,6 +163,7 @@ model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, re
152
  generated_tokens = model.generate(
153
  model_inputs["input_ids"].to(model.device),
154
  attention_mask=model_inputs["attention_mask"].to(model.device),
 
155
  **gen_kwargs,
156
  )
157
 
@@ -161,5 +173,5 @@ decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=Fal
161
  # Extract triplets
162
  for idx, sentence in enumerate(decoded_preds):
163
  print(f'Prediction triplets sentence {idx}')
164
- print(extract_triplets(sentence))
165
  ```
 
20
  - zh
21
  widget:
22
  - text: >-
23
+ Els Red Hot Chili Peppers es van formar a Los Angeles per Kiedis, Flea, el guitarrista Hillel Slovak i el bateria Jack Irons.
 
24
  parameters:
25
  decoder_start_token_id: 250058
26
+ src_lang: "ca_XX"
27
+ tgt_lang: "<triplet>"
28
  tags:
29
  - seq2seq
30
  - relation-extraction
 
57
  ```python
58
  from transformers import pipeline
59
 
60
+ triplet_extractor = pipeline('translation_xx_to_yy', model='Babelscape/mrebel-large', tokenizer='Babelscape/mrebel-large')
61
  # We need to use the tokenizer manually since we need special tokens.
62
+ extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor("The Red Hot Chili Peppers were formed in Los Angeles by Kiedis, Flea, guitarist Hillel Slovak and drummer Jack Irons.", decoder_start_token_id=250058, src_lang="en_XX", tgt_lang="<triplet>", return_tensors=True, return_text=False)[0]["translation_token_ids"]]) # change en_XX for the language of the source.
63
  print(extracted_text[0])
64
  # Function to parse the generated text and extract the triplets
65
+ def extract_triplets_typed(text):
66
  triplets = []
67
+ relation = ''
68
  text = text.strip()
69
  current = 'x'
70
+ subject, relation, object_, object_type, subject_type = '','','','',''
71
+
72
+ for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").replace("tp_XX", "").replace("__en__", "").split():
73
+ if token == "<triplet>" or token == "<relation>":
74
  current = 't'
75
  if relation != '':
76
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
77
  relation = ''
78
  subject = ''
79
+ elif token.startswith("<") and token.endswith(">"):
80
+ if current == 't' or current == 'o':
81
+ current = 's'
82
+ if relation != '':
83
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
84
+ object_ = ''
85
+ subject_type = token[1:-1]
86
+ else:
87
+ current = 'o'
88
+ object_type = token[1:-1]
89
+ relation = ''
90
  else:
91
  if current == 't':
92
  subject += ' ' + token
 
94
  object_ += ' ' + token
95
  elif current == 'o':
96
  relation += ' ' + token
97
+ if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
98
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
99
  return triplets
100
+ extracted_triplets = extract_triplets_typed(extracted_text[0])
101
  print(extracted_triplets)
102
  ```
103
 
 
106
  ```python
107
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
108
 
109
+ def extract_triplets_typed(text):
110
  triplets = []
111
+ relation = ''
112
  text = text.strip()
113
  current = 'x'
114
+ subject, relation, object_, object_type, subject_type = '','','','',''
115
+
116
+ for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").replace("tp_XX", "").replace("__en__", "").split():
117
+ if token == "<triplet>" or token == "<relation>":
118
  current = 't'
119
  if relation != '':
120
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
121
  relation = ''
122
  subject = ''
123
+ elif token.startswith("<") and token.endswith(">"):
124
+ if current == 't' or current == 'o':
125
+ current = 's'
126
+ if relation != '':
127
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
128
+ object_ = ''
129
+ subject_type = token[1:-1]
130
+ else:
131
+ current = 'o'
132
+ object_type = token[1:-1]
133
+ relation = ''
134
  else:
135
  if current == 't':
136
  subject += ' ' + token
 
138
  object_ += ' ' + token
139
  elif current == 'o':
140
  relation += ' ' + token
141
+ if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
142
+ triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),'tail': object_.strip(), 'tail_type': object_type})
143
  return triplets
144
 
145
  # Load model and tokenizer
146
+ tokenizer = AutoTokenizer.from_pretrained("Babelscape/mrebel-large", src_lang="en_XX", "tgt_lang": "tp_XX") # Here we set English as source language. To change the source language just change it here or swap the first token of the input for your desired language
147
+ model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
148
  gen_kwargs = {
149
  "max_length": 256,
150
  "length_penalty": 0,
151
  "num_beams": 3,
152
  "num_return_sequences": 3,
153
+ "forced_bos_token_id": None,
154
  }
155
 
156
  # Text to extract triplets from
 
163
  generated_tokens = model.generate(
164
  model_inputs["input_ids"].to(model.device),
165
  attention_mask=model_inputs["attention_mask"].to(model.device),
166
+ decoder_start_token_id = self.tokenizer.convert_tokens_to_ids("tp_XX"),
167
  **gen_kwargs,
168
  )
169
 
 
173
  # Extract triplets
174
  for idx, sentence in enumerate(decoded_preds):
175
  print(f'Prediction triplets sentence {idx}')
176
+ print(extract_triplets_typed(sentence))
177
  ```