diff --git a/pyscicone/scicone/scicone.py b/pyscicone/scicone/scicone.py
index 39275c3..1125e7f 100644
--- a/pyscicone/scicone/scicone.py
+++ b/pyscicone/scicone/scicone.py
@@ -10,6 +10,7 @@
 import numpy as np
 import pandas as pd
 import phenograph
+from scipy.cluster.hierarchy import linkage
 from collections import Counter
 
 class SCICoNE(object):
@@ -373,6 +374,80 @@ def learn_tree_parallel(self, segmented_data, segmented_region_sizes, new_postfi
 
         return best_tree, robustness_score, trees
 
+    def initialize_cluster_tree(self, data=None, segmented_region_sizes=None, min_cluster_size=1, **kwargs):
+        if segmented_region_sizes is None:
+            segmented_region_sizes = self.bps['segmented_region_sizes']
+        if data is None:
+            data = self.data['filtered_counts']
+
+        if data.shape[1] != segmented_region_sizes.shape[0]:
+            print('Condensing regions...')
+            segmented_data = self.condense_regions(data, segmented_region_sizes)
+            print('Done.')
+        else:
+            segmented_data = data
+
+        if "region_neutral_states" in kwargs:
+            region_neutral_states = np.array(kwargs['region_neutral_states'])
+
+            if np.any(region_neutral_states < 0):
+                raise Exception("Neutral states can not be negative!")
+
+            # If there is a region with neutral state = 0, remove it to facilitate tree inference
+            zero_neutral_regions = np.where(region_neutral_states==0)[0]
+            if len(zero_neutral_regions) > 0:
+                full_segmented_region_sizes = segmented_region_sizes.astype(int)
+                full_segmented_data = segmented_data
+                full_region_neutral_states = region_neutral_states
+                segmented_data = np.delete(segmented_data, zero_neutral_regions, axis=1)
+                segmented_region_sizes = np.delete(segmented_region_sizes, zero_neutral_regions)
+                region_neutral_states = np.delete(region_neutral_states, zero_neutral_regions)
+
+            kwargs['region_neutral_states'] = region_neutral_states
+        else:
+            region_neutral_states = np.ones((segmented_region_sizes.shape[0],)) * 2 # assume diploid
+
+        # Get the average read counts
+        clustered_segmented_data, cluster_sizes, 
cluster_assignments, Q = self.condense_segmented_clusters(segmented_data,
+                                                                                min_cluster_size=min_cluster_size)
+
+        # Normalize by region size
+        clustered_segmented_data = clustered_segmented_data/segmented_region_sizes[None,:]
+
+        # Center at 2
+        clustered_segmented_data = clustered_segmented_data/np.mean(clustered_segmented_data,axis=1)[:,None] * 2
+
+        # Round to integers
+        region_cnvs = np.round(clustered_segmented_data)
+
+        # Get events wrt root
+        region_events = region_cnvs - region_neutral_states
+
+        # Put the clusters on a tree as direct children of the root
+        tree = Tree(self.inference_binary, self.output_path)
+        tree.node_dict['0'] = dict(parent_id='NULL', region_event_dict={})
+        for cluster_idx, cluster_region_events in enumerate(region_events):
+            region_event_dict = {}
+            for region, event in enumerate(cluster_region_events):
+                if event != 0:
+                    region_event_dict[region] = int(event)
+            tree.node_dict[str(int(cluster_idx)+1)] = dict(parent_id='0', region_event_dict=region_event_dict)
+
+        # Get dendrogram of clusters
+        Z = linkage(region_cnvs, 'ward')
+        Z[:,:2] += 1 # because root is 0
+        n = Z.shape[0]+1
+        # Go through the dendrogram and extract common ancestors
+        for nodepair_idx in range(n-1):
+            node1 = str(int(Z[nodepair_idx][0]))
+            node2 = str(int(Z[nodepair_idx][1]))
+            new_node_id = str(int(n+nodepair_idx+1))
+            tree.extract_common_ancestor([node1, node2], new_node_id)
+
+        tree.update_tree_str()
+
+        return tree, clustered_segmented_data, cluster_sizes, cluster_assignments
+
     def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster=True, full=True, cluster_tree_n_iters=4000, nu_tree_n_iters=4000, full_tree_n_iters=4000, max_tries=2, robustness_thr=0.5, min_cluster_size=1, **kwargs):
         if segmented_region_sizes is None:
             segmented_region_sizes = self.bps['segmented_region_sizes']
@@ -417,6 +492,11 @@ def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster=
             if cnt >= max_tries:
                 break
             nu = tree.nu if tree is not None 
else 1.0 + + if tree is not None: + print("Running inference from this initial tree:") + print(tree.tree_str) + tree, robustness_score, trees = self.learn_tree_parallel(clustered_segmented_data, segmented_region_sizes, new_postfix=f"try{cnt}", n_reps=n_reps, nu=nu, cluster_sizes=cluster_sizes, initial_tree=tree, n_iters=cluster_tree_n_iters, verbose=self.verbose, **kwargs) cnt += 1 @@ -475,14 +555,17 @@ def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster= print('Initializing nu for full tree.') # Update the nu on the full data (the nu on the clustered data is very different) with this tree nu = tree.nu - move_probs = kwargs['move_probs'] - kwargs.pop('move_probs', None) + move_probs = None + if 'move_probs' in kwargs: + move_probs = kwargs['move_probs'] + kwargs.pop('move_probs', None) tree = self.learn_single_tree(segmented_data, segmented_region_sizes, nu=nu, initial_tree=tree, n_iters=nu_tree_n_iters, move_probs=[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0], postfix=f"nu_tree_{self.postfix}", verbose=self.verbose, **kwargs) print('Done. Will start from nu={}'.format(tree.nu)) print('Learning full tree...') cnt = 0 robustness_score = 0. 
- kwargs['move_probs'] = move_probs + if move_probs is not None: + kwargs['move_probs'] = move_probs while robustness_score < robustness_thr: if cnt >= max_tries: break diff --git a/pyscicone/scicone/tree.py b/pyscicone/scicone/tree.py index a79fe27..72b1513 100644 --- a/pyscicone/scicone/tree.py +++ b/pyscicone/scicone/tree.py @@ -9,7 +9,7 @@ from scicone.constants import * class Tree(object): - def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persistence=False, ploidy=2, copy_number_limit=6): + def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persistence=False, ploidy=2, copy_number_limit=6, n_bins=-1): self.binary_path = binary_path self.output_path = output_path @@ -43,6 +43,60 @@ def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persis self.cell_node_labels = [] self.num_labels = True + self.n_bins = n_bins + + def get_event_intersection(self, nodes): + intersection = {} + anchor_node = nodes[0] + other_nodes = nodes[1:] + + for key, value in self.node_dict[anchor_node]['region_event_dict'].items(): + direction = 1 if value >= 0 else -1 + minimum = abs(value) + has_intersection = True + for node in other_nodes: + if key in self.node_dict[node]['region_event_dict']: + if (self.node_dict[node]['region_event_dict'][key] >= 0) == (value >= 0): + minimum = min(minimum, abs(self.node_dict[node]['region_event_dict'][key])) + else: + has_intersection = False + break + else: + has_intersection = False + break + + if has_intersection: + if value >= 0: + direction = 1 + else: + direction = -1 + intersection[key] = direction * minimum + + return intersection + + def extract_common_ancestor(self, node_pair, new_node_id=None): + parent_node = self.node_dict[node_pair[0]]['parent_id'] + intersection = self.get_event_intersection(node_pair) + + # Remove intersection from nodes in node_pair + for key, val in intersection.items(): + for node in node_pair: + self.node_dict[node]['region_event_dict'][key] = 
self.node_dict[node]['region_event_dict'][key] - val + if self.node_dict[node]['region_event_dict'][key] == 0: + del self.node_dict[node]['region_event_dict'][key] + + # Add node below parent + if new_node_id is None: + new_node_id = str(np.max(np.array(list(self.node_dict.keys())).astype(int)) + 1) + + self.node_dict[new_node_id] = dict(parent_id=parent_node, + region_event_dict=intersection) + + # Update parent + for node in node_pair: + self.node_dict[node]['parent_id'] = new_node_id + + def get_n_children_per_node(self): n_children = dict() for node_id in self.node_dict: @@ -86,8 +140,13 @@ def get_node_malignancies(self): def set_node_cnvs(self): # Set root state - n_bins = np.sum(self.outputs['region_sizes'].astype(int)) - self.node_dict['0']['cnv'] = np.ones(n_bins,) + n_bins = self.n_bins + if n_bins == -1: + n_bins = np.sum(self.outputs['region_sizes'].astype(int)) + else: + if np.sum(self.outputs['region_sizes'].astype(int)) == n_bins - 1: + self.outputs['region_sizes'][-1] += 1 + self.node_dict['0']['cnv'] = 2*np.ones(n_bins,) bin_start = 0 bin_end = 0 for region, state in enumerate(self.outputs['region_neutral_states']): @@ -230,12 +289,12 @@ def create_cell_node_ids(self): self.outputs['cell_node_ids'] = np.zeros((n_cells, 2)) self.outputs['cell_node_ids'][:,0] = np.arange(n_cells) for node in nodes: - idx = np.where(np.all(self.outputs['inferred_cnvs'] == self.node_dict[node]['cnv'], axis=1)) - self.outputs['cell_node_ids'][idx] = node + idx = np.where(np.all(self.outputs['inferred_cnvs'] == self.node_dict[node]['cnv'], axis=1))[0] + self.outputs['cell_node_ids'][idx,1] = node def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_probs=[0.0,1.0,0.0,1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.01, 0.1, 0.01, 1.0, 0.01], n_nodes=3, seed=42, postfix="", initial_tree=None, nu=1.0, cluster_sizes=None, region_neutral_states=None, alpha=0., gamma=1., max_scoring=True, copy_number_limit=2, - c_penalise=10.0, lambda_r=0.2, lambda_c=0.1, 
ploidy=2, verbosity=2, verbose=False, num_labels=False): + c_penalise=10.0, lambda_r=0.2, lambda_c=0.1, ploidy=2, eta=1e-4, verbosity=2, verbose=False, num_labels=False): if postfix == "": postfix = self.postfix @@ -283,7 +342,7 @@ def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_ f"--move_probs={move_probs_str}", f"--seed={seed}", f"--region_sizes_file={temp_segmented_region_sizes_file}",\ f"--tree_file={temp_tree_file}", f"--nu={nu}", f"--cluster_sizes_file={temp_cluster_sizes_file}", f"--alpha={alpha}",\ f"--max_scoring={max_scoring}", f"--c_penalise={c_penalise}", f"--lambda_r={lambda_r}", f"--gamma={gamma}",\ - f"--lambda_c={lambda_c}", f"--region_neutral_states_file={temp_region_neutral_states_file}"] + f"--lambda_c={lambda_c}", f"--eta={eta}", f"--region_neutral_states_file={temp_region_neutral_states_file}"] if verbose: print(' '.join(cmd)) cmd_output = subprocess.run(cmd) @@ -306,7 +365,7 @@ def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_ f"--move_probs={move_probs_str}", f"--seed={seed}", f"--region_sizes_file={temp_segmented_region_sizes_file}",\ f"--nu={nu}", f"--cluster_sizes_file={temp_cluster_sizes_file}", f"--alpha={alpha}", f"--max_scoring={max_scoring}",\ f"--c_penalise={c_penalise}", f"--lambda_r={lambda_r}", f"--lambda_c={lambda_c}", f"--gamma={gamma}",\ - f"--region_neutral_states_file={temp_region_neutral_states_file}"] + f"--eta={eta}", f"--region_neutral_states_file={temp_region_neutral_states_file}"] if verbose: print(' '.join(cmd)) cmd_output = subprocess.run(cmd)