From 1910893828004035d3b48ca7b6debdc1fdab2b21 Mon Sep 17 00:00:00 2001
From: Jiayu Zhou <95116082+ZebYulon@users.noreply.github.com>
Date: Thu, 1 Aug 2024 10:21:48 -0500
Subject: [PATCH] enable multiple k

Let -nck/--num-clusters-k take either a single integer ("4") or a Python
list/tuple literal of integers ("[3,4,5]"), and run KMeans plus the visual
embedding once per requested k.

---
Review notes (this section is ignored by `git am`):
  * parse_k no longer uses a bare `except:` and now rejects values that
    ast.literal_eval accepts but that cannot be iterated as cluster counts
    (floats, strings, dicts, empty lists, k < 1).
  * This patch was restored from a whitespace-collapsed copy. Indentation
    and the position of blank context lines were reconstructed from the
    hunk line counts and should be confirmed with `git apply --check`
    (add --ignore-whitespace if context does not match).
  * The removed 'Performing Kmeans clustering ...' line appeared to contain
    an extra line-separator character in the pre-image; confirm it against
    the target file.
  * The `index` post-image hash below is the one from the original patch.

 MICA/mica_mds.py | 56 +++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 41 insertions(+), 15 deletions(-)

diff --git a/MICA/mica_mds.py b/MICA/mica_mds.py
index e4c1816..4b46e40 100644
--- a/MICA/mica_mds.py
+++ b/MICA/mica_mds.py
@@ -9,6 +9,7 @@
 import numpy as np
 import pandas as pd
 import pathlib
+import ast
 from sklearn.cluster import KMeans
 from MICA.lib import transform as trans
 from MICA.lib import mutual_info as mi
@@ -45,6 +46,30 @@ def main():
     mica_mds(args)
 
 
+def parse_k(value):
+    """Parse -nck/--num-clusters-k into a list of positive cluster counts.
+
+    Accepts a single integer ('4') or a Python list/tuple literal of
+    integers ('[3,4,5]'). Anything else raises argparse.ArgumentTypeError
+    so that argparse reports a usage error instead of failing later.
+    """
+    error = argparse.ArgumentTypeError(f"Invalid value for -nck: {value}")
+    try:
+        ks = [int(value)]
+    except ValueError:
+        try:
+            ks = ast.literal_eval(value)
+        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
+            raise error from None
+        if not isinstance(ks, (list, tuple)):
+            raise error
+        ks = list(ks)
+    # bool is an int subclass; reject it along with non-ints and k < 1.
+    if not ks or any(isinstance(k, bool) or not isinstance(k, int) or k < 1 for k in ks):
+        raise error
+    return ks
+
+
 def add_mds_arguments(parser):
     parser.add_argument('-dm', '--dr-method', metavar='STR', default='MDS', required=False,
                         help='Transformation method used for dimension reduction '
@@ -56,7 +81,7 @@ def add_mds_arguments(parser):
 
     # parser.add_argument('-nc', '--num-clusters', metavar='INT', nargs='+', required=False, default=0, type=int,
     #                     help='Number of clusters to be specified in kmeans')
-    parser.add_argument('-nck', '--num-clusters-k', metavar='INT', default=4, required=False, type=int,
+    parser.add_argument('-nck', '--num-clusters-k', metavar='INT or List of INT', default=[4], required=False, type=parse_k,
                         help='Number of clusters to be specified in kmeans')
     parser.add_argument('-le', '--louvain-enable', metavar='INT', required=False, default=0, help='enable knn-louvain clustering or not(0)', type=int)
     parser.add_argument('-nn', '--num-neighbors', metavar='INT', required=False, default=20, type=int,
@@ -140,20 +165,21 @@ def mica_mds(args):
     logging.info('(cells, genes): {}'.format(frame.shape))
 
     if not args.louvain_enable:
-        logging.info('Performing Kmeans clustering for # of {}...'.format(args.num_clusters_k))
-        kms = KMeans(n_clusters=args.num_clusters_k)
-        kms.fit(mi_mds)
-        partition = pd.DataFrame(data=[i + 1 for i in list(kms.labels_)], index=frame.index, columns=["label"])
-
-        agg_embed = vs.visual_embed(partition, mi_mds,
-                                    args.output_dir,
-                                    suffix=args.num_clusters_k,
-                                    visual_method=args.visual_method,
-                                    num_works=args.num_workers,
-                                    min_dist=args.min_dist)
-        end = time.time()
-        runtime = end - start
-        logging.info('Done. Runtime: {} seconds'.format(runtime))
+        for k in args.num_clusters_k:
+            logging.info('Performing Kmeans clustering for # of {}...'.format(k))
+            kms = KMeans(n_clusters=k)
+            kms.fit(mi_mds)
+            partition = pd.DataFrame(data=[i + 1 for i in list(kms.labels_)], index=frame.index, columns=["label"])
+
+            agg_embed = vs.visual_embed(partition, mi_mds,
+                                        args.output_dir,
+                                        suffix=k,
+                                        visual_method=args.visual_method,
+                                        num_works=args.num_workers,
+                                        min_dist=args.min_dist)
+            end = time.time()
+            runtime = end - start
+            logging.info('Done. Runtime: {} seconds'.format(runtime))
     else:
         logging.info('Performing KNN clustering ...')
 