#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 28 05:25:30 2022

@author: tanu

Run the multi-classifier ML analysis ("scaling law split") for one gene/drug
pair.

The script prints sanity checks on the training/blind-test data and on the
feature groups, then calls MultModelsCl on the baseline training data and on
each resampled variant (SMOTE NC, ROS, RUS, ROS + RUS combined). For every
run the 'test_*' (CT) and 'bts_*' (BT) score columns are sorted by MCC and
written to CSV files under ``outdir_ml``.
"""
import os
import sys  # needed for sys.exit() in the feature-count check below

gene = 'katG'
drug = 'isoniazid'
#total_mtblineage_uc = 8

homedir = os.path.expanduser("~")
# The project modules below are imported relative to this directory.
os.chdir(homedir + '/git/LSHTM_analysis/scripts/ml/')

# The first star-import brings setvars() into scope. The second one is
# repeated on purpose after the call.
# NOTE(review): presumably setvars() defines the data variables used below
# (X, y, X_bts, outdir, ...) as globals of ml_data_sl, and the re-import is
# what makes them visible here -- confirm against ml_data_sl.
# NOTE(review): `pd`, `Counter` and the feature-name lists are not imported
# explicitly in this file; they are expected to arrive via this star-import.
from ml_data_sl import *
setvars(gene, drug)
from ml_data_sl import *

# from YC run_all_ML: run locally
#from UQ_yc_RunAllClfs import run_all_ML

# TT run all ML clfs: baseline mode
from MultModelsCl import MultModelsCl

################################################################################
print('\n#####################################################################\n',
      '\nRunning ML analysis: scaling law split',
      '\nGene name:', gene,
      '\nDrug name:', drug)

#==================
# Specify outdir
#==================
outdir_ml = outdir + 'ml/tts_sl/'
print('\nOutput directory:', outdir_ml)

# Create the output directory before any model is run, so that a missing
# directory cannot make the CSV writes fail after the classifiers finish.
os.makedirs(outdir_ml, exist_ok=True)

#%%###########################################################################
print('\nSanity checks:',
      #'\nML source data size:', x_features.shape,
      '\nTotal input features:', len(X.columns),
      '\n',
      '\nTraining data size:', X.shape,
      '\nTest data size:', X_bts.shape,
      '\n',
      '\nTarget feature numbers (training data):', Counter(y),
      '\nTarget features ratio (training data:', yc1_ratio,
      '\n',
      '\nTarget feature numbers (test data):', Counter(y_bts),
      '\nTarget features ratio (test data):', yc2_ratio,
      '\n\n#####################################################################\n')

print('\n================================================================\n')

print('Strucutral features (n):',
      len(X_ssFN),
      '\nThese are:',
      '\nCommon stablity features:', X_stabilityN,
      '\nFoldX columns:', X_foldX_cols,
      '\nOther struc columns:', X_str,
      '\n================================================================\n')

print('AAindex features (n):',
      len(X_aaindexFN),
      '\nThese are:\n',
      X_aaindexFN,
      '\n================================================================\n')

print('Evolutionary features (n):',
      len(X_evolFN),
      '\nThese are:\n',
      X_evolFN,
      '\n================================================================\n')

print('Genomic features (n):',
      len(X_genomicFN),
      '\nThese are:\n',
      X_genomic_mafor, '\n',
      X_genomic_linegae,
      '\n================================================================\n')

print('Categorical features (n):',
      len(categorical_FN),
      '\nThese are:\n',
      categorical_FN,
      '\n================================================================\n')

# The feature groups must add up to the number of columns in the training
# data; stop the run otherwise.
n_grouped_features = (len(X_ssFN)
                      + len(X_aaindexFN)
                      + len(X_evolFN)
                      + len(X_genomicFN)
                      + len(categorical_FN))

if len(X.columns) == n_grouped_features:
    print('\nPass: No. of features match')
else:
    sys.exit('\nFail: Count of feature mismatch')

print('\n#####################################################################\n')

################################################################################

def _run_models_and_save(input_df, target, label):
    """Run MultModelsCl on one training set and write its score tables.

    Reads the module-level names ``skf_cv``, ``X_bts``, ``y_bts``,
    ``outdir_ml`` and ``gene`` at call time.

    Parameters
    ----------
    input_df
        Training features passed to MultModelsCl as ``input_df``.
    target
        Training target passed to MultModelsCl as ``target``.
    label
        Tag used in the output file names, e.g. 'baseline' gives
        ``<gene>_baseline_CT_allF.csv`` and ``<gene>_baseline_BT_allF.csv``.

    Returns
    -------
    tuple
        ``(scores, all_scores, ct_scores, bt_scores)``: the raw MultModelsCl
        result, that result as a transposed DataFrame, the columns whose
        name contains 'test_' sorted by 'test_mcc' (descending), and the
        columns whose name contains 'bts_' sorted by 'bts_mcc' (descending).
    """
    scores = MultModelsCl(input_df=input_df,
                          target=target,
                          var_type='mixed',
                          skf_cv=skf_cv,
                          blind_test_input_df=X_bts,
                          blind_test_target=y_bts)

    all_scores = pd.DataFrame(scores).T

    #train_scores = all_scores.filter(like='train_', axis=1)
    # Non-inplace sort: same ordering as before, without mutating a frame
    # derived from another one.
    ct_scores = (all_scores.filter(like='test_', axis=1)
                 .sort_values(by=['test_mcc'], ascending=False))
    bt_scores = (all_scores.filter(like='bts_', axis=1)
                 .sort_values(by=['bts_mcc'], ascending=False))

    # Write csv
    file_stem = outdir_ml + gene.lower() + '_' + label
    ct_scores.to_csv(file_stem + '_CT_allF.csv')
    bt_scores.to_csv(file_stem + '_BT_allF.csv')

    return scores, all_scores, ct_scores, bt_scores


#==================
# Baseline models
#==================
mm_skf_scoresD, baseline_all, baseline_CT, baseline_BT = _run_models_and_save(
    X, y, 'baseline')

#%% SMOTE NC: Oversampling [Numerical + categorical]
mm_skf_scoresD7, smnc_all, smnc_CT, smnc_BT = _run_models_and_save(
    X_smnc, y_smnc, 'smnc')

#%% ROS: Numerical + categorical
mm_skf_scoresD3, ros_all, ros_CT, ros_BT = _run_models_and_save(
    X_ros, y_ros, 'ros')

#%% RUS: Numerical + categorical
mm_skf_scoresD4, rus_all, rus_CT, rus_BT = _run_models_and_save(
    X_rus, y_rus, 'rus')

#%% ROS + RUS Combined: Numerical + categorical
mm_skf_scoresD8, rouC_all, rouC_CT, rouC_BT = _run_models_and_save(
    X_rouC, y_rouC, 'rouC')