minor var bame update in ml_iterator

This commit is contained in:
Tanushree Tunstall 2022-07-09 10:52:50 +01:00
parent 8079dd7b6c
commit 8bde6f0640
3 changed files with 39 additions and 37 deletions

View file

@ -144,10 +144,9 @@ scoreBT_mapD = {'bts_mcc' : 'MCC'
############################ ############################
# Multiple Classification - Model Pipeline # Multiple Classification - Model Pipeline
def MultModelsCl(input_df, target def MultModelsCl(input_df, target
#, skf_cv
, sel_cv , sel_cv
#, blind_test_df , blind_test_df
#, blind_test_target , blind_test_target
, tts_split_type , tts_split_type
, resampling_type = 'none' # default , resampling_type = 'none' # default
@ -230,37 +229,37 @@ def MultModelsCl(input_df, target
#====================================================== #======================================================
# Specify multiple Classification Models # Specify multiple Classification Models
#====================================================== #======================================================
models = [('AdaBoost Classifier' , AdaBoostClassifier(**rs) ) models = [('AdaBoost Classifier' , AdaBoostClassifier(**rs) )
# , ('Bagging Classifier' , BaggingClassifier(**rs, **njobs, bootstrap = True, oob_score = True, verbose = 3, n_estimators = 100) ) , ('Bagging Classifier' , BaggingClassifier(**rs, **njobs, bootstrap = True, oob_score = True, verbose = 3, n_estimators = 100) )
# , ('Decision Tree' , DecisionTreeClassifier(**rs) ) , ('Decision Tree' , DecisionTreeClassifier(**rs) )
# , ('Extra Tree' , ExtraTreeClassifier(**rs) ) , ('Extra Tree' , ExtraTreeClassifier(**rs) )
# , ('Extra Trees' , ExtraTreesClassifier(**rs) ) , ('Extra Trees' , ExtraTreesClassifier(**rs) )
# , ('Gradient Boosting' , GradientBoostingClassifier(**rs) ) , ('Gradient Boosting' , GradientBoostingClassifier(**rs) )
# , ('Gaussian NB' , GaussianNB() ) , ('Gaussian NB' , GaussianNB() )
# , ('Gaussian Process' , GaussianProcessClassifier(**rs) ) , ('Gaussian Process' , GaussianProcessClassifier(**rs) )
# , ('K-Nearest Neighbors' , KNeighborsClassifier() ) , ('K-Nearest Neighbors' , KNeighborsClassifier() )
, ('LDA' , LinearDiscriminantAnalysis() ) , ('LDA' , LinearDiscriminantAnalysis() )
, ('Logistic Regression' , LogisticRegression(**rs) ) , ('Logistic Regression' , LogisticRegression(**rs) )
# , ('Logistic RegressionCV' , LogisticRegressionCV(cv = 3, **rs)) , ('Logistic RegressionCV' , LogisticRegressionCV(cv = 3, **rs))
# , ('MLP' , MLPClassifier(max_iter = 500, **rs) ) , ('MLP' , MLPClassifier(max_iter = 500, **rs) )
#, ('Multinomial' , MultinomialNB() ) , ('Multinomial' , MultinomialNB() )
# , ('Naive Bayes' , BernoulliNB() ) , ('Naive Bayes' , BernoulliNB() )
# , ('Passive Aggresive' , PassiveAggressiveClassifier(**rs, **njobs) ) , ('Passive Aggresive' , PassiveAggressiveClassifier(**rs, **njobs) )
# , ('QDA' , QuadraticDiscriminantAnalysis() ) , ('QDA' , QuadraticDiscriminantAnalysis() )
# , ('Random Forest' , RandomForestClassifier(**rs, n_estimators = 1000, **njobs ) ) # , ('Random Forest' , RandomForestClassifier(**rs, n_estimators = 1000, **njobs ) )
# # , ('Random Forest2' , RandomForestClassifier(min_samples_leaf = 5 , ('Random Forest2' , RandomForestClassifier(min_samples_leaf = 5
# , n_estimators = 1000 , n_estimators = 1000
# , bootstrap = True , bootstrap = True
# , oob_score = True , oob_score = True
# , **njobs , **njobs
# , **rs , **rs
# , max_features = 'auto') ) , max_features = 'auto') )
# , ('Ridge Classifier' , RidgeClassifier(**rs) ) , ('Ridge Classifier' , RidgeClassifier(**rs) )
# , ('Ridge ClassifierCV' , RidgeClassifierCV(cv = 3) ) , ('Ridge ClassifierCV' , RidgeClassifierCV(cv = 3) )
# , ('SVC' , SVC(**rs) ) , ('SVC' , SVC(**rs) )
# , ('Stochastic GDescent' , SGDClassifier(**rs, **njobs) ) , ('Stochastic GDescent' , SGDClassifier(**rs, **njobs) )
, ('XGBoost' , XGBClassifier(**rs, verbosity = 0, use_label_encoder =False, **njobs) ) , ('XGBoost' , XGBClassifier(**rs, verbosity = 0, use_label_encoder =False, **njobs) )
#
] ]
mm_skf_scoresD = {} mm_skf_scoresD = {}

View file

@ -45,10 +45,13 @@ spl_type = '70_30'
#spl_type = '80_20' #spl_type = '80_20'
#spl_type = 'sl' #spl_type = 'sl'
#data_type = "actual"
data_type = "complete"
df2 = split_tts(df df2 = split_tts(df
, data_type = 'actual' , data_type = data_type
, split_type = spl_type , split_type = spl_type
, oversampling = False , oversampling = True
, dst_colname = 'dst' , dst_colname = 'dst'
, target_colname = 'dst_mode' , target_colname = 'dst_mode'
, include_gene_name = True , include_gene_name = True
@ -67,8 +70,8 @@ Counter(df2['y'])
Counter(df2['y_bts']) Counter(df2['y_bts'])
fooD = MultModelsCl(input_df = df2['X'] fooD = MultModelsCl(input_df = df2['X_ros']
, target = df2['y'] , target = df2['y_ros']
, sel_cv = skf_cv , sel_cv = skf_cv
, run_blind_test = True , run_blind_test = True
, blind_test_df = df2['X_bts'] , blind_test_df = df2['X_bts']
@ -83,7 +86,7 @@ fooD = MultModelsCl(input_df = df2['X']
for k, v in fooD.items(): for k, v in fooD.items():
print('\nModel:', k print('\nModel:', k
, '\nTRAIN MCC:', fooD[k]['test_mcc'] , '\nTRAIN MCC:', fooD[k]['test_mcc']
, '\nBTS MCC:' , fooD[k]['bts_mcc'] , '\nBTS MCC:' , fooD[k]['bts_mcc']
, '\nDIFF:',fooD[k]['bts_mcc'] - fooD[k]['test_mcc'] ) , '\nDIFF:',fooD[k]['bts_mcc'] - fooD[k]['test_mcc'] )
#%% CHECK SCALING #%% CHECK SCALING

View file

@ -25,7 +25,7 @@ from GetMLData import *
from SplitTTS import * from SplitTTS import *
# param dict for getmldata() # param dict for getmldata()
combined_model_paramD = {'data_combined_model' : False gene_model_paramD = {'data_combined_model' : False
, 'use_or' : False , 'use_or' : False
, 'omit_all_genomic_features': False , 'omit_all_genomic_features': False
, 'write_maskfile' : False , 'write_maskfile' : False
@ -48,7 +48,7 @@ for gene, drug in ml_gene_drugD.items():
, '\nDrug:', drug) , '\nDrug:', drug)
gene_low = gene.lower() gene_low = gene.lower()
gene_dataD[gene_low] = getmldata(gene, drug gene_dataD[gene_low] = getmldata(gene, drug
, **combined_model_paramD) , **gene_model_paramD)
for split_type in split_types: for split_type in split_types:
for data_type in split_data_types: for data_type in split_data_types: