Spaces:
Runtime error
Runtime error
File size: 27,687 Bytes
370675b 9997114 370675b d348741 fbbf27f d348741 fbbf27f d348741 fbbf27f 9997114 2f7253d 9997114 d348741 14d74fa 2f7253d 9997114 2f7253d 9997114 2f7253d 9997114 d348741 14d74fa 9997114 2f7253d 9997114 14d74fa 9997114 2f7253d 9997114 2f7253d 9997114 a362d69 9997114 fbbf27f d348741 2f7253d d348741 fbbf27f d348741 fbbf27f d348741 fbbf27f 9997114 2f7253d fbbf27f 9997114 fbbf27f 370675b 14d74fa 370675b 2f7253d 370675b 2f7253d 370675b 2f7253d 370675b fbbf27f 370675b fbbf27f 370675b fbbf27f 370675b 2f7253d 370675b 2f7253d fbbf27f 370675b 2f7253d fbbf27f 370675b fbbf27f 370675b 2f7253d fbbf27f 370675b 2f7253d 370675b fbbf27f 370675b fbbf27f 370675b fbbf27f 370675b 2f7253d 370675b fbbf27f 370675b fbbf27f 370675b fbbf27f 370675b a880e5e 370675b 14d74fa 370675b 2f7253d 370675b 9997114 2f7253d 9997114 2f7253d 9997114 2f7253d 9997114 a880e5e 9997114 2f7253d 9997114 370675b 2f7253d 370675b 9997114 d348741 9997114 fbbf27f 370675b 9997114 370675b 9997114 d348741 9997114 14d74fa 370675b 14d74fa 370675b 14d74fa a880e5e 370675b 14d74fa 370675b 14d74fa 370675b 14d74fa 370675b 14d74fa 370675b 14d74fa 370675b 9997114 14d74fa 9997114 370675b 14d74fa d348741 14d74fa 2f7253d 14d74fa 370675b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 |
#!/usr/bin/env python3
import argparse
import json
import os
import zipfile
import pandas as pd
from catboost import CatBoostClassifier, Pool
# Edit-action codes. They label each hop of an edit trace and are the classes
# the action model is trained to predict.
# MATCH: the next input character is copied to the output unchanged.
MATCH = 0
# DELETE_FROM: the next input character is consumed without being emitted.
DELETE_FROM = 1
# INSERT_TO: a character is appended to the output without consuming input.
INSERT_TO = 2
# START: root of an edit trace; it is never turned into a training row.
START = 3
# Version stamp written to params.json inside a saved model; load() refuses
# files whose version is greater than this.
FILE_VERSION = 1
class Transmorgrifier:
    """
    Character-level text rewriter built from two CatBoost classifiers.

    The action model decides at every step whether to match, delete or insert.
    The char model decides which character to insert; when training only ever
    saw one distinct inserted character (or none) a constant is used instead.
    """

    def train(self, from_sentences, to_sentences, iterations=4000, device='cpu', trailing_context=7, leading_context=7, verbose=True):
        """
        Train the Transmorgrifier model. This does not save it to disk but just trains in memory.
        Keyword arguments:
        from_sentences -- An array of strings for the input sentences.
        to_sentences -- An array of strings of the same length as from_sentences which the model is to train to convert to.
        iterations -- An integer specifying the number of training iterations handed to catboost. (default 4000)
        device -- The gpu reference which catboost wants or "cpu". (default cpu)
        trailing_context -- The number of characters after the action point to include for context. (default 7)
        leading_context -- The number of characters before the action point to include for context. (default 7)
        verbose -- Increases the amount of text output during training. (default True)
        Returns self so calls can be chained.
        """
        X, Y = _parse_for_training(
            from_sentences,
            to_sentences,
            num_pre_context_chars=leading_context,
            num_post_context_chars=trailing_context,
        )

        # The action model is trained on every row.
        self.action_model = _train_catboost(X, Y['action'], iterations, verbose=verbose, device=device, model_piece='action')

        # The char model only learns from the rows where the action is an insert.
        insert_indexes = Y['action'] == INSERT_TO
        inserted_chars = Y['char'][insert_indexes]
        distinct_chars = inserted_chars.nunique()

        if distinct_chars > 1:
            self.char_model = _train_catboost(X[insert_indexes], inserted_chars, iterations, verbose=verbose, device=device, model_piece='char')
            self.constant_output = None
        else:
            # With a single distinct character there is nothing to classify, so
            # remember that character instead of training a second model.
            self.char_model = None
            if distinct_chars == 1:
                self.constant_output = inserted_chars.unique()[0]
            else:
                # No insertion was ever seen. A space is stored purely so that
                # the save and load code has a constant to round-trip.
                self.constant_output = ' '

        self.trailing_context = trailing_context
        self.leading_context = leading_context
        self.iterations = iterations
        return self

    def save(self, model='my_model.tm'):
        """
        Saves the model previously trained with train to a specified model file.
        The file is a zip holding params.json, action.cb and, when a char model
        was trained, char.cb.
        Keyword arguments:
        model -- The pathname to save the model such as "my_model.tm" (default my_model.tm)
        Returns self so calls can be chained.
        """
        self.name = model

        params = {
            'version': FILE_VERSION,
            'leading_context': self.leading_context,
            'trailing_context': self.trailing_context,
            'iterations': self.iterations,
        }
        if self.constant_output is not None:
            params['constant_output'] = self.constant_output

        # catboost writes to a filename, so each model goes through a temp file
        # before being copied into the archive.
        temp_filename = _mktemp()
        try:
            with zipfile.ZipFile(model, mode="w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as my_zip:
                with my_zip.open('params.json', mode='w') as out:
                    out.write(json.dumps(params).encode())

                self.action_model.save_model(temp_filename)
                my_zip.write(temp_filename, "action.cb")

                if self.char_model is not None:
                    self.char_model.save_model(temp_filename)
                    my_zip.write(temp_filename, "char.cb")
        finally:
            # Remove the temp file even when saving fails part way through.
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)
        return self

    @staticmethod
    def _load_model_member(archive, member, temp_filename):
        """Copy one archive member to temp_filename and load it as a catboost model."""
        with archive.open(member) as f_in, open(temp_filename, "wb") as f_out:
            f_out.write(f_in.read())
        return CatBoostClassifier().load_model(temp_filename)

    def load(self, model='my_model.tm'):
        """
        Loads the model previously saved from the file system.
        Keyword arguments:
        model -- The filename of the model to load. (default my_model.tm)
        Raises an Exception when the file was written with a newer FILE_VERSION.
        Returns self so calls can be chained.
        """
        self.name = model
        with zipfile.ZipFile(model, mode='r') as archive:
            with archive.open('params.json') as f_in:
                params = json.loads(f_in.read().decode())

            if params['version'] > FILE_VERSION:
                raise Exception(f"Version {params['version']} greater than {FILE_VERSION}")

            self.leading_context = int(params['leading_context'])
            self.trailing_context = int(params['trailing_context'])
            self.iterations = int(params['iterations'])

            temp_filename = _mktemp()
            try:
                self.action_model = self._load_model_member(archive, 'action.cb', temp_filename)

                # A stored constant means no char model was saved.
                if 'constant_output' not in params:
                    self.char_model = self._load_model_member(archive, 'char.cb', temp_filename)
                    self.constant_output = None
                else:
                    self.constant_output = params['constant_output']
                    self.char_model = None
            finally:
                # Remove the temp file even when loading fails part way through.
                if os.path.exists(temp_filename):
                    os.unlink(temp_filename)
        return self

    def execute(self, from_sentences, verbose=False):
        """
        Runs the data from from_sentences. The results are returned
        using yield so you need to wrap this in list() if you want
        to index it. from_sentences can be an array or a generator.
        Keyword arguments:
        from_sentences -- Something iterable which returns strings.
        verbose -- Print a progress line every tenth sentence. (default False)
        """
        # A generator has no len(), so the total is only reported when known.
        try:
            total = len(from_sentences)
        except TypeError:
            total = None

        for i, from_sentence in enumerate(from_sentences):
            yield _do_reconstruct(
                action_model=self.action_model,
                char_model=self.char_model,
                constant_output=self.constant_output,
                text=from_sentence,
                num_pre_context_chars=self.leading_context,
                num_post_context_chars=self.trailing_context,
            )
            if verbose and i % 10 == 0:
                if total is None:
                    print(f"{i} of ?")
                else:
                    print(f"{i} of {total}")

    def demo(self, share=False):
        """
        Launch a gradio page with an input box whose changes are run through the model.
        Keyword arguments:
        share -- Passed to gradio's launch to request a public URL. (default False)
        """
        import gradio as gr

        def gradio_function(text):
            return list(self.execute([text]))[0]

        # name is only set by save() and load(); fall back to a generic title
        # for a model that was trained in memory and never saved.
        title = getattr(self, 'name', 'Transmorgrifier')

        with gr.Blocks() as app:
            gr.Markdown(title)
            inp = gr.Textbox(label="Input")
            out = gr.Textbox(label="Output")
            inp.change(gradio_function, inputs=[inp], outputs=[out])
        app.launch(share=share)
def _list_trace(trace):
    """
    Return the chain of hops ending at trace as a list ordered root first.

    The chain is followed through each hop's parent attribute until a hop with
    parent None is reached. This is done with a loop rather than recursion so
    that a long trace cannot exceed the interpreter's recursion limit.
    """
    hops = []
    node = trace
    while node is not None:
        hops.append(node)
        node = node.parent
    # Collected leaf first, so flip it to get the root at index 0.
    hops.reverse()
    return hops
class _edit_trace_hop():
    """One step of an edit trace, linked to the step before it through parent."""

    # Class-level defaults; _trace_edits overwrites each of them per instance.
    parent = None
    edit_distance = None
    char = None
    from_row_i = None
    to_column_i = None
    action = None

    def __str__(self):
        # START carries no character, so it is rendered on its own.
        if self.action == START:
            return "<start>"
        # The remaining actions share one layout: a tag followed by the character.
        for code, tag in ((INSERT_TO, "ins"), (DELETE_FROM, "del"), (MATCH, "match")):
            if self.action == code:
                return f"<{tag}> {self.char}"
        return "eh?"

    def __repr__(self):
        return self.__str__()
def _trace_edits(from_sentence, to_sentence, print_debug=False):
    """
    Build the chain of edits which turns from_sentence into to_sentence.

    A matrix is walked with from_sentence down the rows and to_sentence across
    the columns, keeping only the previous row while the next one is filled.
    Index 0 on either axis stands for "before the first character".

    Arguments:
    from_sentence -- The string being edited.
    to_sentence -- The string the edits should produce.
    print_debug -- When true the finished chain is printed, root first.

    Returns the _edit_trace_hop for the bottom right cell. Following its parent
    links leads back to the START hop.
    """
    def make_hop(parent, char, action, from_row_i, to_column_i):
        # Every hop costs one more than the hop it came from; the root costs 0.
        hop = _edit_trace_hop()
        hop.parent = parent
        hop.edit_distance = 0 if parent is None else parent.edit_distance + 1
        hop.char = char
        hop.from_row_i = from_row_i
        hop.to_column_i = to_column_i
        hop.action = action
        return hop

    last_row = None
    for from_row_i in range(len(from_sentence) + 1):
        current_row = []
        for to_column_i in range(len(to_sentence) + 1):
            best_option = None

            # root case.
            if from_row_i == 0 and to_column_i == 0:
                best_option = make_hop(None, "", START, from_row_i, to_column_i)

            # check left: arriving from the left inserts a char of to_sentence.
            if to_column_i > 0:
                left = current_row[to_column_i - 1]
                if best_option is None or left.edit_distance + 1 < best_option.edit_distance:
                    best_option = make_hop(left, to_sentence[to_column_i - 1], INSERT_TO, from_row_i, to_column_i)

            # check up: arriving from above deletes a char of from_sentence.
            if from_row_i > 0:
                above = last_row[to_column_i]
                if best_option is None or above.edit_distance + 1 < best_option.edit_distance:
                    best_option = make_hop(above, from_sentence[from_row_i - 1], DELETE_FROM, from_row_i, to_column_i)

            # check match: the diagonal only exists off the first row and first
            # column. Without the row guard, row 0 would compare against
            # from_sentence[-1] and index last_row while it is still None.
            if from_row_i > 0 and to_column_i > 0:
                if to_sentence[to_column_i - 1] == from_sentence[from_row_i - 1]:
                    diagonal = last_row[to_column_i - 1]
                    # prefer match so use <= than <
                    if best_option is None or diagonal.edit_distance <= best_option.edit_distance:
                        best_option = make_hop(diagonal, from_sentence[from_row_i - 1], MATCH, from_row_i, to_column_i)

            if best_option is None:
                raise Exception("Shouldn't end up with best_option being None")
            current_row.append(best_option)
        last_row = current_row

    final_hop = last_row[-1]

    if print_debug:
        # Walk the parent links, then print root first.
        chain = []
        node = final_hop
        while node is not None:
            chain.append(node)
            node = node.parent
        for node in reversed(chain):
            if node.action == START:
                print("start")
            elif node.action == MATCH:
                print(f"match {node.char}")
            elif node.action == INSERT_TO:
                print(f"insert {node.char}")
            elif node.action == DELETE_FROM:
                print(f"del {node.char}")

    return final_hop
def _parse_single_for_training(from_sentence, to_sentence, num_pre_context_chars, num_post_context_chars):
    """
    Turn one sentence pair into training rows, one row per edit action.

    The edit trace is replayed against from_sentence. Before each action is
    applied a snapshot is taken of the upcoming input characters, the output
    written so far, the input consumed so far and the two run counters.

    Returns a pair of DataFrames: the feature columns (f*, t*, u*,
    continuous_added, continuous_dropped) and the answers (action, char).
    """
    hops = _list_trace(_trace_edits(from_sentence, to_sentence))

    pre_padding = " " * num_pre_context_chars
    post_padding = " " * num_post_context_chars

    remaining = from_sentence   # input not yet consumed
    produced = ""               # output written so far
    consumed = ""               # input already consumed
    added_run = 0               # inserts in a row
    dropped_run = 0             # deletes in a row

    snapshots = []
    for hop in hops:
        # The START hop is neither recorded nor replayed.
        if hop.action == START:
            continue

        # gather action and context for training
        snapshots.append({
            "from_context": (remaining + post_padding)[:num_post_context_chars],
            "to_context": (pre_padding + produced)[-num_pre_context_chars:],
            "used_context": (pre_padding + consumed)[-num_pre_context_chars:],
            "action": hop.action,
            "continuous_added": added_run,
            "continuous_dropped": dropped_run,
            "char": hop.char if hop.action == INSERT_TO else ' ',
        })

        # now execute the action for the next step.
        if hop.action == INSERT_TO:
            produced += hop.char
            added_run += 1
            dropped_run = 0
        elif hop.action == DELETE_FROM:
            consumed += remaining[0]
            remaining = remaining[1:]
            added_run = 0
            dropped_run += 1
        elif hop.action == MATCH:
            consumed += remaining[0]
            produced += remaining[0]
            remaining = remaining[1:]
            added_run = 0
            dropped_run = 0

    # Replaying the trace should rebuild the target exactly.
    if to_sentence != produced:
        print("Replay failure")

    # The snapshots are a list of rows; pandas wants a dictionary of columns,
    # one column per context character position.
    features = {}
    for i in range(num_post_context_chars):
        features[f"f{i}"] = [snapshot['from_context'][i] for snapshot in snapshots]
    for i in range(num_pre_context_chars):
        features[f"t{i}"] = [snapshot['to_context'][i] for snapshot in snapshots]
    for i in range(num_pre_context_chars):
        features[f"u{i}"] = [snapshot['used_context'][i] for snapshot in snapshots]
    features["continuous_added"] = [snapshot["continuous_added"] for snapshot in snapshots]
    features["continuous_dropped"] = [snapshot["continuous_dropped"] for snapshot in snapshots]

    # The output answers get the same treatment.
    answers = {
        'action': [snapshot['action'] for snapshot in snapshots],
        'char': [snapshot['char'] for snapshot in snapshots],
    }

    return pd.DataFrame(features), pd.DataFrame(answers)
def _parse_for_training(from_sentences, to_sentences, num_pre_context_chars, num_post_context_chars):
    """
    Parse every sentence pair and stack the results into two DataFrames.

    Pairs where either side is a float are skipped, because a bad line arrives
    as nan, which is a float. isinstance is used so that float subclasses are
    recognised as well. Progress is printed every hundred pairs.

    Returns the concatenated feature frame and the concatenated answer frame.
    Raises ValueError when no pair was usable.
    """
    out_observations_list = []
    out_results_list = []

    for index, (from_sentence, to_sentence) in enumerate(zip(from_sentences, to_sentences)):
        if not isinstance(from_sentence, float) and not isinstance(to_sentence, float):
            specific_observation, specific_result = _parse_single_for_training(
                from_sentence,
                to_sentence,
                num_pre_context_chars=num_pre_context_chars,
                num_post_context_chars=num_post_context_chars,
            )
            out_observations_list.append(specific_observation)
            out_results_list.append(specific_result)

        if index % 100 == 0:
            print(f"parsing {index} of {len(from_sentences)}")

    # pandas refuses to concatenate an empty list; say why in plain words.
    if not out_observations_list:
        raise ValueError("No usable sentence pairs were found to train on")

    return pd.concat(out_observations_list), pd.concat(out_results_list)
def _train_catboost(X, y, iterations, device, verbose, model_piece, learning_rate=.07):
    """
    Fit one CatBoostClassifier and return it.

    Arguments:
    X -- The feature frame. Missing values are replaced with a space.
    y -- The labels to learn.
    iterations -- Passed to CatBoostClassifier as iterations.
    device -- "cpu" (any case) trains on the CPU; anything else selects GPU
              training and is passed through as the devices setting.
    verbose -- Controls catboost's own training output as well as the summary
               printed once fitting is done.
    model_piece -- A label used only in the printed summary.
    learning_rate -- Passed to CatBoostClassifier. (default .07)
    """
    X = X.fillna(' ')
    use_cpu = device.lower() == 'cpu'

    train_pool = Pool(
        data=X,
        label=y,
        # The context columns are categorical and are recognised by their first
        # letter. Watch out if another field is added that it doesn't start
        # with one of these.
        cat_features=[i for i, x in enumerate(X.keys()) if x[0] in ['f', 't', 'u']],
    )

    # Can't use validation pool because it randomly has chars not in training.
    validation_pool = None

    model = CatBoostClassifier(
        iterations=iterations,
        learning_rate=learning_rate,
        task_type="CPU" if use_cpu else "GPU",
        devices=None if use_cpu else device,
    )
    # Honour the caller's verbose flag rather than always logging.
    model.fit(train_pool, eval_set=validation_pool, verbose=bool(verbose))

    if verbose:
        print('{} is fitted: {}'.format(model_piece, model.is_fitted()))
        print('{} params:\n{}'.format(model_piece, model.get_params()))
    return model
def _mktemp():
    """
    Create an empty temporary file and return its path.

    catboost can't write to a file handle yet, so an actual filename is needed.
    tempfile.mkstemp creates the file itself, which avoids the gap between
    checking that a name is free and using it. The handle is closed straight
    away and the caller is responsible for deleting the file.
    """
    import tempfile

    handle, path = tempfile.mkstemp(prefix=".temp_", suffix="~")
    os.close(handle)
    return path
def predict_wrapper(model, model_input):
    """
    Return the first prediction of model for model_input as a bare value.

    Big hack. Catboost has shown itself to be unstable on producing either a
    single value or an array with a single value in it, depending on the saved
    model. The first element is taken, and if that is itself indexable its
    first element is taken too, so the caller gets the same thing either way.
    """
    result = model.predict(model_input)[0]
    try:
        result = result[0]
    except (TypeError, IndexError):
        # Already a bare value: plain numbers raise TypeError when indexed and
        # numpy scalars raise IndexError.
        pass
    return result
def _do_reconstruct(action_model, char_model, constant_output, text, num_pre_context_chars, num_post_context_chars):
    """
    Rewrite text by repeatedly asking the models which edit to apply next.

    Arguments:
    action_model -- Model predicting MATCH, DELETE_FROM or INSERT_TO.
    char_model -- Model predicting the character to insert. Only consulted
                  when constant_output is None.
    constant_output -- The character to insert when there is no char model.
    text -- The string to rewrite. nan is treated as an empty string.
    num_pre_context_chars -- Width of the output and consumed-input contexts.
    num_post_context_chars -- Width of the upcoming-input context.

    Returns the rewritten string. The loop stops when the input is used up,
    when the output reaches three times the input length, or when the last
    five output characters are all the same.
    """
    # test for nan: it is the only value not equal to itself.
    if text != text:
        text = ''

    pre_padding = " " * num_pre_context_chars
    post_padding = " " * num_post_context_chars

    working_from = text
    working_to = ""
    used_from = ""
    continuous_added = 0
    continuous_dropped = 0

    while (working_from
           and len(working_to) < 3 * len(text)
           and (len(working_to) < 5 or working_to[-5:] != (working_to[-1] * 5))):
        from_context = (working_from + post_padding)[:num_post_context_chars]
        to_context = (pre_padding + working_to)[-num_pre_context_chars:]
        used_context = (pre_padding + used_from)[-num_pre_context_chars:]

        # Build a one row frame with the same columns used in training.
        context_as_dictionary = {}
        for i in range(num_post_context_chars):
            context_as_dictionary[f"f{i}"] = [from_context[i]]
        for i in range(num_pre_context_chars):
            context_as_dictionary[f"t{i}"] = [to_context[i]]
        for i in range(num_pre_context_chars):
            context_as_dictionary[f"u{i}"] = [used_context[i]]
        context_as_dictionary["continuous_added"] = [continuous_added]
        context_as_dictionary["continuous_dropped"] = [continuous_dropped]
        context_as_pd = pd.DataFrame(context_as_dictionary)

        action_model_result = predict_wrapper(action_model, context_as_pd)

        # stop run away. If we have added more chars then our context, nothing is going to change.
        if action_model_result == INSERT_TO and continuous_added >= num_post_context_chars:
            action_model_result = MATCH

        if action_model_result == INSERT_TO:
            if constant_output is None:
                # for an insert ask the char model what to insert
                char_model_result = predict_wrapper(char_model, context_as_pd)
            else:
                char_model_result = constant_output
            working_to += char_model_result
            continuous_added += 1
            continuous_dropped = 0
        elif action_model_result == DELETE_FROM:
            used_from += working_from[0]
            working_from = working_from[1:]
            continuous_added = 0
            continuous_dropped += 1
        else:
            # MATCH. START or an unrecognised action lands here as well: those
            # would otherwise change nothing, the next prediction would see the
            # identical context, and the loop would never end.
            used_from += working_from[0]
            working_to += working_from[0]
            working_from = working_from[1:]
            continuous_added = 0
            continuous_dropped = 0

    return working_to
#edit distance from https://stackoverflow.com/a/32558749/1419054
def _levenshteinDistance(s1, s2):
    """
    Return the edit distance between s1 and s2.

    Insertions, deletions and substitutions each cost one. A nan argument is
    treated as an empty string.
    """
    # nan is the only value not equal to itself.
    s1 = '' if s1 != s1 else s1
    s2 = '' if s2 != s2 else s2

    # Keep the shorter string on the inner loop so the rows stay short.
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    previous_row = list(range(len(shorter) + 1))
    for row_number, long_char in enumerate(longer, start=1):
        current_row = [row_number]
        for column, short_char in enumerate(shorter):
            if short_char == long_char:
                current_row.append(previous_row[column])
            else:
                cheapest = min(previous_row[column], previous_row[column + 1], current_row[-1])
                current_row.append(1 + cheapest)
        previous_row = current_row
    return previous_row[-1]
def train(in_csv, a_header, b_header, model, iterations, device, leading_context, trailing_context, train_percentage, verbose):
    """
    Train a model from a csv and save it.

    The first train_percentage percent of the rows of in_csv are used. Column
    a_header supplies the source sentences and column b_header the targets.
    The remaining arguments are handed to Transmorgrifier.train, and the
    result is written to the path given by model.
    """
    if verbose:
        print("loading csv")
    frame = pd.read_csv(in_csv)

    # Rows before the cutoff are used for training.
    cutoff = int(train_percentage / 100 * len(frame))
    training_rows = frame.iloc[:cutoff, :].reset_index(drop=True)

    if verbose:
        print("parsing data for training")

    settings = {
        'iterations': iterations,
        'device': device,
        'leading_context': leading_context,
        'trailing_context': trailing_context,
        'verbose': verbose,
    }
    trainer = Transmorgrifier()
    trainer.train(
        from_sentences=training_rows[a_header],
        to_sentences=training_rows[b_header],
        **settings
    )
    trainer.save(model)
def execute(include_stats, in_csv, out_csv, a_header, b_header, model, execute_percentage, verbose):
    """
    Run a saved model over the tail of a csv and write the results to out_csv.

    Arguments:
    include_stats -- When true the output also holds the input, the expected
                     text from b_header, the edit distance to the expected
                     text before and after conversion, and the percentage
                     improvement between the two.
    in_csv -- The csv to read.
    out_csv -- The csv to write.
    a_header -- The column holding the text to convert.
    b_header -- The column holding the expected text. Only read when
                include_stats is true.
    model -- The model file to load.
    execute_percentage -- The last this-many percent of the rows are converted.
    verbose -- Print progress.
    """
    if verbose:
        print("loading csv")
    full_data = pd.read_csv(in_csv)

    # Rows from the split point to the end are executed.
    split_index = int((100 - execute_percentage) / 100 * len(full_data))
    execute_data = full_data.iloc[split_index:, :].reset_index(drop=True)

    tm = Transmorgrifier()
    tm.load(model)
    results = list(tm.execute(execute_data[a_header], verbose=verbose))

    if include_stats:
        before_edit_distances = []
        after_edit_distances = []
        percent_improvement = []
        for source, expected, generated in zip(execute_data[a_header], execute_data[b_header], results):
            before = _levenshteinDistance(source, expected)
            after = _levenshteinDistance(generated, expected)
            before_edit_distances.append(before)
            after_edit_distances.append(after)
            # max(1, ...) avoids dividing by zero when the input already matched.
            percent_improvement.append(100 * (before - after) / max(1, before))

        pd_results = pd.DataFrame({
            "in_data": execute_data[a_header],
            "out_data": execute_data[b_header],
            "generated_data": results,
            "before_edit_distance": before_edit_distances,
            "after_edit_distance": after_edit_distances,
            "percent_improvement": percent_improvement,
        })
    else:
        # Write what the model produced. Previously the b_header column was
        # written here and the generated text was thrown away.
        pd_results = pd.DataFrame({
            "out_data": results,
        })

    pd_results.to_csv(out_csv)
def safe_float(str):
    """Convert str to a float, passing None through untouched."""
    return None if str is None else float(str)
def main():
    """
    Command line entry point.

    Parses the arguments and then, in this order, trains a model, executes a
    model and starts the gradio demo, each only if its flag was given. When
    both training and executing are requested without --train_percentage the
    data is split evenly between them.
    """
    parser = argparse.ArgumentParser(
        prog='transmorgrify.py',
        description='Converts text from one to another according to a model.',
        epilog='(C) Joshua Lansford')

    def flag(short_name, long_name, help_text):
        # An on/off switch.
        parser.add_argument(short_name, long_name, action='store_true', help=help_text)

    def option(short_name, long_name, help_text, **extra):
        # An option which takes a value.
        parser.add_argument(short_name, long_name, help=help_text, **extra)

    flag('-t', '--train', 'Train a model instead of executing a model')
    flag('-e', '--execute', 'Use an existing trained model.')
    flag('-g', '--gradio', 'Start a gradio demo with the selected model.')
    flag('-s', '--share', "Share the gradio app with a temporary public URL.")
    option('-i', '--in_csv', 'The csv to read training or input data from', default='in.csv')
    option('-o', '--out_csv', 'The csv to write conversion to', default='out.csv')
    option('-a', '--a_header', 'The column header for training or transforming from', default="source")
    option('-b', '--b_header', 'The column header for training the transformation to', default="target")
    option('-m', '--model', 'The model file to create during training or use during transformation', default='model.tm')
    option('-n', '--iterations', 'The number of iterations to train', default=2000)
    option('-d', '--device', 'Which device, i.e. if using GPU', default='cpu')
    option('-x', '--context', 'The number of leading and trailing chars to use as context', default=7)
    option('-p', '--train_percentage', "The percentage of data to train on, leaving the rest for testing.")
    flag('-v', '--verbose', 'Talks alot?')
    flag('-c', '--include_stats', 'Use b_header to compute stats and add to output csv.')

    args = parser.parse_args()

    if not (args.train or args.execute or args.gradio):
        print("Must include --execute, --train and/or --gradio to do something.")

    if args.train:
        train_percentage = safe_float(args.train_percentage)
        if train_percentage is None:
            # Leave half of the data for executing when both were asked for.
            train_percentage = 50 if args.execute else 100

        context_width = int(args.context)
        train(
            in_csv=args.in_csv,
            a_header=args.a_header,
            b_header=args.b_header,
            model=args.model,
            iterations=int(args.iterations),
            device=args.device,
            leading_context=context_width,
            trailing_context=context_width,
            train_percentage=train_percentage,
            verbose=args.verbose,
        )

    if args.execute:
        if args.train_percentage is None:
            execute_percentage = 50 if args.train else 100
        else:
            # Execute on whatever training was told not to use.
            execute_percentage = 100 - safe_float(args.train_percentage)

        execute(
            include_stats=args.include_stats,
            in_csv=args.in_csv,
            out_csv=args.out_csv,
            a_header=args.a_header,
            b_header=args.b_header,
            model=args.model,
            execute_percentage=execute_percentage,
            verbose=args.verbose,
        )

    if args.gradio:
        tm = Transmorgrifier()
        tm.load(args.model)
        tm.demo(share=args.share)


# Only run the command line interface when started as a script.
if __name__ == '__main__':
    main()
|