ipd commited on
Commit
1a3f627
·
1 Parent(s): d4278f3

limit custom dataset

Browse files
Files changed (1) hide show
  1. models/fm4m.py +8 -8
models/fm4m.py CHANGED
@@ -343,11 +343,11 @@ def single_modal(model,dataset=None, downstream_model=None, params=None, x_train
343
  print("Custom Dataset")
344
  #return
345
  components = dataset.split(",")
346
- train_data = pd.read_csv(components[0])[components[2]]
347
- test_data = pd.read_csv(components[1])[components[2]]
348
 
349
- y_batch = pd.read_csv(components[0])[components[3]]
350
- y_batch_test = pd.read_csv(components[1])[components[3]]
351
 
352
 
353
  x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
@@ -610,11 +610,11 @@ def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_tr
610
  elif x_train==None:
611
  predefined = False
612
  components = dataset.split(",")
613
- train_data = pd.read_csv(components[0])[components[2]]
614
- test_data = pd.read_csv(components[1])[components[2]]
615
 
616
- y_batch = pd.read_csv(components[0])[components[3]]
617
- y_batch_test = pd.read_csv(components[1])[components[3]]
618
 
619
  print("Custom Dataset loaded")
620
  else:
 
343
  print("Custom Dataset")
344
  #return
345
  components = dataset.split(",")
346
+ train_data = pd.read_csv(components[0])[components[2]][:100]
347
+ test_data = pd.read_csv(components[1])[components[2]][:50]
348
 
349
+ y_batch = pd.read_csv(components[0])[components[3]][:100]
350
+ y_batch_test = pd.read_csv(components[1])[components[3]][:50]
351
 
352
 
353
  x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
 
610
  elif x_train==None:
611
  predefined = False
612
  components = dataset.split(",")
613
+ train_data = pd.read_csv(components[0])[components[2]][:100]
614
+ test_data = pd.read_csv(components[1])[components[2]][:50]
615
 
616
+ y_batch = pd.read_csv(components[0])[components[3]][:100]
617
+ y_batch_test = pd.read_csv(components[1])[components[3]][:50]
618
 
619
  print("Custom Dataset loaded")
620
  else: