From e5f882841e0fb39f12023508e17f6ca23d26bacd Mon Sep 17 00:00:00 2001 From: Tanushree Tunstall Date: Sat, 2 Jul 2022 16:57:41 +0100 Subject: [PATCH] added cm_datai.py to get data for cm model for running fs later --- scripts/ml/combined_model/cm_datai.py | 136 ++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 scripts/ml/combined_model/cm_datai.py diff --git a/scripts/ml/combined_model/cm_datai.py b/scripts/ml/combined_model/cm_datai.py new file mode 100644 index 0000000..b41e980 --- /dev/null +++ b/scripts/ml/combined_model/cm_datai.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jun 29 19:44:06 2022 + +@author: tanu +""" +import sys, os +import pandas as pd +import numpy as np +import re +from copy import deepcopy +from sklearn import linear_model +from sklearn import datasets +from collections import Counter +############################################################################### +homedir = os.path.expanduser("~") +sys.path.append(homedir + '/git/LSHTM_analysis/scripts/ml/ml_functions') +sys.path +############################################################################### +outdir = homedir + '/git/LSHTM_ML/output/combined/' + +#==================== +# Import ML functions +#==================== +from ml_data_combined import * +from MultClfs_logo_skf import * +#from GetMLData import * +#from SplitTTS import * + +skf_cv = StratifiedKFold(n_splits = 10 , shuffle = True, random_state = 42) + +#logo = LeaveOneGroupOut() + +######################################################################## +# COMPLETE data: No tts_split +######################################################################## +#%% +def CMLogoData(cm_input_df = pd.DataFrame() + , all_genes = ["embb", "katg", "rpob", "pnca", "gid", "alr"] + , bts_genes = ["embb", "katg" + , "rpob", "pnca", "gid" + ] + , cols_to_drop = ['dst', 'dst_mode', 'gene_name'] + , target_var = 'dst_mode' + , gene_group = 'gene_name' + , std_gene_omit = [] + ): + + cm_dataD = {} + for bts_gene in bts_genes: + print('\n BTS gene:', bts_gene) + + if not std_gene_omit: + training_genesL = ['alr'] + else: + training_genesL = [] + + tr_gene_omit = std_gene_omit + [bts_gene] + n_tr_genes = (len(bts_genes) - (len(std_gene_omit))) + #n_total_genes = (len(bts_genes) - len(std_gene_omit)) + n_total_genes = len(all_genes) + + training_genesL = training_genesL + list(set(bts_genes) - set(tr_gene_omit)) + #training_genesL = [element for element in bts_genes if element not in tr_gene_omit] + + print('\nTotal genes: ', n_total_genes + ,'\nTraining on:', n_tr_genes + ,'\nTraining on genes:', training_genesL + , '\nOmitted genes:', tr_gene_omit + , '\nBlind test gene:', bts_gene) + + tts_split_type = "logo_skf_BT_" + bts_gene + + outFile = outdir + str(n_tr_genes+1) + "genes_" + tts_split_type + ".csv" + + print(outFile) + + bts_geneD = {} + #------- + # training + #------ + cm_training_df = cm_input_df[~cm_input_df['gene_name'].isin(tr_gene_omit)] + + cm_X = cm_training_df.drop(cols_to_drop, axis=1, inplace=False) + #cm_y = cm_training_df.loc[:,'dst_mode'] + cm_y = cm_training_df.loc[:, target_var] + + gene_group = cm_training_df.loc[:,'gene_name'] + + print('\nTraining data dim:', cm_X.shape + , '\nTraining Target dim:', cm_y.shape) + + if all(cm_X.columns.isin(cols_to_drop) == False): + print('\nChecked training df does NOT have Target var') + else: + sys.exit('\nFAIL: training data contains Target var') + + #--------------- + # BTS: genes + #--------------- + cm_test_df = cm_input_df[cm_input_df['gene_name'].isin([bts_gene])] + + cm_bts_X = cm_test_df.drop(cols_to_drop, axis = 1, inplace = False) + #cm_bts_y = cm_test_df.loc[:, 'dst_mode'] + cm_bts_y = cm_test_df.loc[:, target_var] + + print('\nTEST data dim:', cm_bts_X.shape + , '\nTEST Target dim:', cm_bts_y.shape) + + + bts_geneD = {'cm_X' : cm_X + , 'cm_y' : cm_y + , 'cm_bts_X': cm_bts_X + , 'cm_bts_y': cm_bts_y} + + cm_dataD[bts_gene] = bts_geneD + + return(cm_dataD) + +#%% +df_complete_6g = CMLogoData(cm_input_df = combined_df, std_gene_omit=[] ) +df_complete_5g = CMLogoData(cm_input_df = combined_df, std_gene_omit=['alr']) +# checks + +len(df_complete_6g['embb']['cm_X']) +#len(df_complete_6g['embb']['cm_y']) + +len(df_complete_5g['embb']['cm_X']) +#len(df_complete_5g['embb']['cm_y']) + +df_actual_6g = CMLogoData(cm_input_df = combined_df_actual, std_gene_omit=[] ) +df_actual_5g = CMLogoData(cm_input_df = combined_df_actual, std_gene_omit=['alr']) + +len(df_actual_6g['embb']['cm_X']) +len(df_actual_5g['embb']['cm_X'])