checking the splitTTS script to make sure other splits have been factored in

This commit is contained in:
Tanushree Tunstall 2022-07-27 15:52:20 +01:00
parent f4cab1fdfb
commit 63c8876764
2 changed files with 21 additions and 21 deletions

View file

@ -47,7 +47,7 @@ njobs = {'n_jobs': os.cpu_count() } # the number of jobs should equal the number
# NOTE: split_type 'none_with_bts' and 'none_only': WORK on complete data ONLY irrespective of data_type # NOTE: split_type 'none_with_bts' and 'none_only': WORK on complete data ONLY irrespective of data_type
def split_tts(ml_input_data def split_tts(ml_input_data
, data_type = ['actual', 'complete'] , data_type = ['actual', 'complete']
, split_type = ['70_30', '80_20', 'sl', 'none_with_bts', 'none_only', 'reverse'] , split_type = ['70_30', '80_20', 'sl', 'rt', 'none_bts', 'none']
, oversampling = True , oversampling = True
, dst_colname = 'dst'# determine how to subset the actual vs reverse data , dst_colname = 'dst'# determine how to subset the actual vs reverse data
, target_colname = 'dst_mode' , target_colname = 'dst_mode'
@ -116,19 +116,7 @@ def split_tts(ml_input_data
if split_type == 'sl': if split_type == 'sl':
tts_test_size = 1/np.sqrt(x_ncols) tts_test_size = 1/np.sqrt(x_ncols)
train_sl = 1 - tts_test_size # for reference train_sl = 1 - tts_test_size # for reference
if split_type == 'none_with_bts': # always on complete data
temp_df_train = ml_input_data[ml_input_data[dst_colname].notna()]
X = temp_df_train.drop(cols_to_dropL, axis = 1)
y = temp_df_train[target_colname]
temp_df_bts = ml_input_data[ml_input_data[dst_colname].isna()]
X_bts = temp_df_bts.drop(cols_to_dropL, axis = 1)
y_bts = temp_df_bts[target_colname]
n_test_data_size = len(X) + len(X_bts)
test_data_shape = X_bts.shape
if split_type == 'rt': # always on complete data if split_type == 'rt': # always on complete data
temp_df_train = ml_input_data[ml_input_data[dst_colname].isna()] temp_df_train = ml_input_data[ml_input_data[dst_colname].isna()]
X = temp_df_train.drop(cols_to_dropL, axis = 1) X = temp_df_train.drop(cols_to_dropL, axis = 1)
@ -138,10 +126,22 @@ def split_tts(ml_input_data
X_bts = temp_df_bts.drop(cols_to_dropL, axis = 1) X_bts = temp_df_bts.drop(cols_to_dropL, axis = 1)
y_bts = temp_df_bts[target_colname] y_bts = temp_df_bts[target_colname]
n_test_data_size = len(X) + len(X_bts)
test_data_shape = X_bts.shape
if split_type == 'none_bts': # always on complete data
temp_df_train = ml_input_data[ml_input_data[dst_colname].notna()]
X = temp_df_train.drop(cols_to_dropL, axis = 1)
y = temp_df_train[target_colname]
temp_df_bts = ml_input_data[ml_input_data[dst_colname].isna()]
X_bts = temp_df_bts.drop(cols_to_dropL, axis = 1)
y_bts = temp_df_bts[target_colname]
n_test_data_size = len(X) + len(X_bts) n_test_data_size = len(X) + len(X_bts)
test_data_shape = X_bts.shape test_data_shape = X_bts.shape
if split_type == 'none_only': if split_type == 'none': # always on complete data
temp_df_train = ml_input_data.copy() # always complete temp_df_train = ml_input_data.copy() # always complete
X = temp_df_train.drop(cols_to_dropL, axis = 1) X = temp_df_train.drop(cols_to_dropL, axis = 1)
y = temp_df_train[target_colname] y = temp_df_train[target_colname]
@ -163,10 +163,10 @@ def split_tts(ml_input_data
yc1 = Counter(y) yc1 = Counter(y)
yc1_ratio = yc1[0]/yc1[1] yc1_ratio = yc1[0]/yc1[1]
if split_type in ['none_only']: if split_type in ['none']:
outDict.update({'X' : X outDict.update({'X' : X
, 'y' : y , 'y' : y })
})
yc2 = "NO Blind test data" yc2 = "NO Blind test data"
yc2_ratio = "NO Blind test data" yc2_ratio = "NO Blind test data"
n_test_data_size = "NO Blind test data" n_test_data_size = "NO Blind test data"

View file

@ -41,9 +41,9 @@ gene_model_paramD = {'data_combined_model' : False
############################################################################### ###############################################################################
#ml_genes = ["pncA", "embB", "katG", "rpoB", "gid"] #ml_genes = ["pncA", "embB", "katG", "rpoB", "gid"]
ml_gene_drugD = {#'pncA' : 'pyrazinamide' ml_gene_drugD = {'pncA' : 'pyrazinamide'
# 'embB' : 'ethambutol' , 'embB' : 'ethambutol'
'katG' : 'isoniazid' , 'katG' : 'isoniazid'
, 'rpoB' : 'rifampicin' , 'rpoB' : 'rifampicin'
, 'gid' : 'streptomycin' , 'gid' : 'streptomycin'
} }