Spaces:
Runtime error
Runtime error
vinayakdev
commited on
Commit
•
918bd10
1
Parent(s):
39105ec
Check again
Browse files- generator.py +14 -3
generator.py
CHANGED
@@ -57,9 +57,9 @@ def run_model(input_string, **generator_args):
|
|
57 |
|
58 |
input_string = "generate questions: " + input_string + " </s>"
|
59 |
|
60 |
-
inputs =
|
61 |
|
62 |
-
res = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], **
|
63 |
output = hftokenizer.decode(res[0], skip_special_tokens=True)
|
64 |
output = [item.split("<sep>") for item in output]
|
65 |
output = [o.strip() for o in output[:-1]]
|
@@ -107,7 +107,18 @@ def gen_question(inputs):
|
|
107 |
# string_query = "Hello World"
|
108 |
# gen_question(f"answer: {string_query} context: The first C program said {string_query} "). #The format of the query to generate questions
|
109 |
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def read_file(filepath_name):
|
112 |
with open(text, "r") as infile:
|
113 |
contents = infile.read()
|
|
|
57 |
|
58 |
input_string = "generate questions: " + input_string + " </s>"
|
59 |
|
60 |
+
inputs = tokenize(input_string)
|
61 |
|
62 |
+
res = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], **generate_wargs)
|
63 |
output = hftokenizer.decode(res[0], skip_special_tokens=True)
|
64 |
output = [item.split("<sep>") for item in output]
|
65 |
output = [o.strip() for o in output[:-1]]
|
|
|
107 |
# string_query = "Hello World"
|
108 |
# gen_question(f"answer: {string_query} context: The first C program said {string_query} "). #The format of the query to generate questions
|
109 |
|
110 |
+
def tokenize(inputs) :
|
111 |
+
inputs = self.tokenizer.batch_encode_plus(
|
112 |
+
inputs,
|
113 |
+
max_length="512",
|
114 |
+
add_special_tokens=True,
|
115 |
+
truncation=True,
|
116 |
+
padding="max_length",
|
117 |
+
pad_to_max_length=True,
|
118 |
+
return_tensors="pt"
|
119 |
+
)
|
120 |
+
return inputs
|
121 |
+
|
122 |
def read_file(filepath_name):
|
123 |
with open(text, "r") as infile:
|
124 |
contents = infile.read()
|