#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 29 20:29:36 2022

@author: tanu

Driver script: for every gene/drug pair, every train/test split type and
every data type, fetch the ML data, split it, run MultModelsCl() once per
resampling variant and write the combined, sorted scores to one CSV file
named <gene>_<split_type>_<data_type>.csv inside `outdir`.
"""
import sys
import os
import pandas as pd
import numpy as np
import re
###############################################################################
homedir = os.path.expanduser("~")
# The project ML helpers live outside the package path; make them importable.
sys.path.append(homedir + '/git/LSHTM_analysis/scripts/ml/ml_functions')
###############################################################################
outdir = homedir + '/git/LSHTM_ML/output/genes/'
# Create the output directory up front so a missing directory cannot make
# to_csv() fail only after the models have already been run.
os.makedirs(outdir, exist_ok=True)

#====================
# Import ML functions
#====================
# NOTE(review): StratifiedKFold, getmldata, split_tts and MultModelsCl are not
# imported explicitly; presumably they are provided by these star imports.
from MultClfs import *
from GetMLData import *
from SplitTTS import *

# Cross-validation splitter handed to every MultModelsCl() call.
skf_cv = StratifiedKFold(n_splits = 10
                         #, shuffle = False, random_state= None)
                         , shuffle = True, random_state = 42)

#rskf_cv = RepeatedStratifiedKFold(n_splits = 10
#                                  , n_repeats = 3
#                                  , **rs)

# param dict for getmldata()
gene_model_paramD = {'data_combined_model'        : False
                     , 'use_or'                   : False
                     , 'omit_all_genomic_features': False
                     , 'write_maskfile'           : False
                     , 'write_outfile'            : False }

###############################################################################
#ml_genes = ["pncA", "embB", "katG", "rpoB", "gid"]

# gene -> drug pairs to process
ml_gene_drugD = {  'pncA' : 'pyrazinamide'
                 , 'embB' : 'ethambutol'
                 , 'katG' : 'isoniazid'
                 , 'rpoB' : 'rifampicin'
                 , 'gid'  : 'streptomycin'
                }

# getmldata() result per gene, keyed by lower-cased gene name
gene_dataD = {}

split_types = [
                 '70_30',
                 '80_20',
                 'sl',
                 'rt',
                 'none_bts'
                ]

split_data_types = [
                     #'actual',
                     'complete'
                     ]

# (paramD key, suffix of the X/y keys read from split_tts() output,
#  'resampling_type' value passed to MultModelsCl()).
# The baseline set reads the plain 'X'/'y' keys, hence the empty suffix.
_resampling_sets = [
      ('baseline_paramD', ''     , 'none')
    , ('smnc_paramD'    , '_smnc', 'smnc')
    , ('ros_paramD'     , '_ros' , 'ros')
    , ('rus_paramD'     , '_rus' , 'rus')
    , ('rouC_paramD'    , '_rouC', 'rouC')
    ]

for gene, drug in ml_gene_drugD.items():
    print ('\nGene:', gene
           , '\nDrug:', drug)
    gene_low = gene.lower()
    gene_dataD[gene_low] = getmldata(gene, drug
                                     , **gene_model_paramD)

    for split_type in split_types:
        for data_type in split_data_types:

            out_filename = outdir + gene_low + '_' + split_type + '_' + data_type + '.csv'

            tempD = split_tts(gene_dataD[gene_low]
                              , data_type = data_type
                              , split_type = split_type
                              , oversampling = True
                              , dst_colname = 'dst'
                              , target_colname = 'dst_mode'
                              , include_gene_name = True
                              )

            # One kwargs dict per resampling variant, in the same key order
            # as _resampling_sets.
            paramD = {
                set_name: {  'input_df'        : tempD['X' + suffix]
                           , 'target'          : tempD['y' + suffix]
                           , 'var_type'        : 'mixed'
                           , 'resampling_type' : resampling_type}
                for set_name, suffix, resampling_type in _resampling_sets
                }

            # Scores per resampling variant; the blind-test data is the same
            # for every variant.
            mmDD = {}
            for k, v in paramD.items():
                scoresD = MultModelsCl(**v
                                       , sel_cv = skf_cv
                                       , tts_split_type = split_type
                                       , add_cm = True
                                       , add_yn = True
                                       , scale_numeric = ['min_max']
                                       , run_blind_test = True
                                       , blind_test_df = tempD['X_bts']
                                       , blind_test_target = tempD['y_bts']
                                       , return_formatted_output = True
                                       , random_state = 42
                                       , n_jobs = os.cpu_count()
                                       )
                mmDD[k] = scoresD

            # Extract the dfs from within the dict and concatenate them once
            # into a single df (the dict keys are dropped by ignore_index).
            # NOTE(review): assumes each MultModelsCl() result is a DataFrame
            # with 'resampling', 'source_data' and 'MCC' columns -- confirm.
            out_wf = pd.concat(mmDD, ignore_index = True)
            out_wf_f = out_wf.sort_values(by = ['resampling', 'source_data', 'MCC']
                                          , ascending = [True, True, False]
                                          , inplace = False)
            out_wf_f.to_csv(out_filename, index = False)