Skip to content

Commit

Permalink
Add initialization from dendrogram on clustered data
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed May 12, 2024
1 parent 875781d commit 20d0b75
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 11 deletions.
89 changes: 86 additions & 3 deletions pyscicone/scicone/scicone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
import phenograph
from scipy.cluster.hierarchy import linkage
from collections import Counter

class SCICoNE(object):
Expand Down Expand Up @@ -373,6 +374,80 @@ def learn_tree_parallel(self, segmented_data, segmented_region_sizes, new_postfi

return best_tree, robustness_score, trees

def initialize_cluster_tree(self, data=None, segmented_region_sizes=None, min_cluster_size=1, **kwargs):
    """Build an initial CNV tree from a hierarchical clustering of the cells.

    Cells are condensed into clusters, each cluster's average (rounded)
    copy-number profile becomes a node attached directly to the root, and
    the Ward dendrogram of the cluster profiles is then walked to pull
    shared events up into common-ancestor nodes.

    Parameters
    ----------
    data : np.ndarray, optional
        Cells-by-bins (or cells-by-regions) count matrix. Defaults to
        ``self.data['filtered_counts']``.
    segmented_region_sizes : np.ndarray, optional
        Number of bins in each segmented region. Defaults to
        ``self.bps['segmented_region_sizes']``.
    min_cluster_size : int
        Minimum number of cells for a cluster to be kept.
    **kwargs
        May contain ``region_neutral_states`` (per-region neutral copy
        number states); otherwise diploid (2) is assumed everywhere.

    Returns
    -------
    tuple
        ``(tree, clustered_segmented_data, cluster_sizes, cluster_assignments)``

    Raises
    ------
    Exception
        If any provided neutral state is negative.
    """
    if segmented_region_sizes is None:
        segmented_region_sizes = self.bps['segmented_region_sizes']
    if data is None:
        data = self.data['filtered_counts']

    # If the data is still per-bin, aggregate it into per-region counts.
    if data.shape[1] != segmented_region_sizes.shape[0]:
        print('Condensing regions...')
        segmented_data = self.condense_regions(data, segmented_region_sizes)
        print('Done.')
    else:
        segmented_data = data

    if "region_neutral_states" in kwargs:
        region_neutral_states = np.array(kwargs['region_neutral_states'])

        # BUGFIX: the comparison must happen inside np.any -- the previous
        # `np.any(x) < 0` compared a boolean to 0 and could never raise.
        if np.any(region_neutral_states < 0):
            raise Exception("Neutral states can not be negative!")

        # If there is a region with neutral state = 0, remove it to facilitate tree inference
        zero_neutral_regions = np.where(region_neutral_states == 0)[0]
        if len(zero_neutral_regions) > 0:
            segmented_data = np.delete(segmented_data, zero_neutral_regions, axis=1)
            segmented_region_sizes = np.delete(segmented_region_sizes, zero_neutral_regions)
            region_neutral_states = np.delete(region_neutral_states, zero_neutral_regions)

            kwargs['region_neutral_states'] = region_neutral_states
    else:
        region_neutral_states = np.ones((segmented_region_sizes.shape[0],)) * 2  # assume diploid

    # Get the average read counts per cluster (clusters x regions)
    clustered_segmented_data, cluster_sizes, cluster_assignments, Q = self.condense_segmented_clusters(segmented_data,
                                        min_cluster_size=min_cluster_size)

    # Normalize by region size. BUGFIX: broadcast region sizes along the
    # region (column) axis; the previous `[:, None]` made a (regions x 1)
    # column vector which only broadcast when n_clusters == n_regions.
    clustered_segmented_data = clustered_segmented_data / segmented_region_sizes[None, :]

    # Center each cluster profile at copy number 2
    clustered_segmented_data = clustered_segmented_data / np.mean(clustered_segmented_data, axis=1)[:, None] * 2

    # Round to integer copy numbers
    region_cnvs = np.round(clustered_segmented_data)

    # Events of each cluster relative to the root (neutral) state
    region_events = region_cnvs - region_neutral_states

    # Put the clusters on a tree as direct children of the root (id '0')
    tree = Tree(self.inference_binary, self.output_path)
    tree.node_dict['0'] = dict(parent_id='NULL', region_event_dict={})
    for cluster_idx, cluster_region_events in enumerate(region_events):
        region_event_dict = {region: int(event)
                             for region, event in enumerate(cluster_region_events)
                             if event != 0}
        tree.node_dict[str(cluster_idx + 1)] = dict(parent_id='0', region_event_dict=region_event_dict)

    # Get dendrogram of the cluster profiles
    Z = linkage(region_cnvs, 'ward')
    Z[:, :2] += 1  # shift cluster indices because the root occupies id 0
    n = Z.shape[0] + 1
    # Walk the dendrogram merges and extract common ancestors; merge i of
    # scipy's linkage creates cluster n+i, which after the +1 shift is n+i+1.
    for nodepair_idx in range(n - 1):
        node1 = str(int(Z[nodepair_idx][0]))
        node2 = str(int(Z[nodepair_idx][1]))
        new_node_id = str(int(n + nodepair_idx + 1))
        tree.extract_common_ancestor([node1, node2], new_node_id)

    tree.update_tree_str()

    return tree, clustered_segmented_data, cluster_sizes, cluster_assignments

def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster=True, full=True, cluster_tree_n_iters=4000, nu_tree_n_iters=4000, full_tree_n_iters=4000, max_tries=2, robustness_thr=0.5, min_cluster_size=1, **kwargs):
if segmented_region_sizes is None:
segmented_region_sizes = self.bps['segmented_region_sizes']
Expand Down Expand Up @@ -417,6 +492,11 @@ def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster=
if cnt >= max_tries:
break
nu = tree.nu if tree is not None else 1.0

if tree is not None:
print("Running inference from this initial tree:")
print(tree.tree_str)

tree, robustness_score, trees = self.learn_tree_parallel(clustered_segmented_data, segmented_region_sizes, new_postfix=f"try{cnt}", n_reps=n_reps, nu=nu, cluster_sizes=cluster_sizes, initial_tree=tree, n_iters=cluster_tree_n_iters, verbose=self.verbose, **kwargs)
cnt += 1

Expand Down Expand Up @@ -475,14 +555,17 @@ def learn_tree(self, data=None, segmented_region_sizes=None, n_reps=10, cluster=
print('Initializing nu for full tree.')
# Update the nu on the full data (the nu on the clustered data is very different) with this tree
nu = tree.nu
move_probs = kwargs['move_probs']
kwargs.pop('move_probs', None)
move_probs = None
if 'move_probs' in kwargs:
move_probs = kwargs['move_probs']
kwargs.pop('move_probs', None)
tree = self.learn_single_tree(segmented_data, segmented_region_sizes, nu=nu, initial_tree=tree, n_iters=nu_tree_n_iters, move_probs=[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0], postfix=f"nu_tree_{self.postfix}", verbose=self.verbose, **kwargs)
print('Done. Will start from nu={}'.format(tree.nu))
print('Learning full tree...')
cnt = 0
robustness_score = 0.
kwargs['move_probs'] = move_probs
if move_probs is not None:
kwargs['move_probs'] = move_probs
while robustness_score < robustness_thr:
if cnt >= max_tries:
break
Expand Down
75 changes: 67 additions & 8 deletions pyscicone/scicone/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scicone.constants import *

class Tree(object):
def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persistence=False, ploidy=2, copy_number_limit=6):
def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persistence=False, ploidy=2, copy_number_limit=6, n_bins=-1):
self.binary_path = binary_path

self.output_path = output_path
Expand Down Expand Up @@ -43,6 +43,60 @@ def __init__(self, binary_path, output_path, postfix='PYSCICONETREETEMP', persis
self.cell_node_labels = []
self.num_labels = True

self.n_bins = n_bins

def get_event_intersection(self, nodes):
intersection = {}
anchor_node = nodes[0]
other_nodes = nodes[1:]

for key, value in self.node_dict[anchor_node]['region_event_dict'].items():
direction = 1 if value >= 0 else -1
minimum = abs(value)
has_intersection = True
for node in other_nodes:
if key in self.node_dict[node]['region_event_dict']:
if (self.node_dict[node]['region_event_dict'][key] >= 0) == (value >= 0):
minimum = min(minimum, abs(self.node_dict[node]['region_event_dict'][key]))
else:
has_intersection = False
break
else:
has_intersection = False
break

if has_intersection:
if value >= 0:
direction = 1
else:
direction = -1
intersection[key] = direction * minimum

return intersection

def extract_common_ancestor(self, node_pair, new_node_id=None):
parent_node = self.node_dict[node_pair[0]]['parent_id']
intersection = self.get_event_intersection(node_pair)

# Remove intersection from nodes in node_pair
for key, val in intersection.items():
for node in node_pair:
self.node_dict[node]['region_event_dict'][key] = self.node_dict[node]['region_event_dict'][key] - val
if self.node_dict[node]['region_event_dict'][key] == 0:
del self.node_dict[node]['region_event_dict'][key]

# Add node below parent
if new_node_id is None:
new_node_id = str(np.max(np.array(list(self.node_dict.keys())).astype(int)) + 1)

self.node_dict[new_node_id] = dict(parent_id=parent_node,
region_event_dict=intersection)

# Update parent
for node in node_pair:
self.node_dict[node]['parent_id'] = new_node_id


def get_n_children_per_node(self):
n_children = dict()
for node_id in self.node_dict:
Expand Down Expand Up @@ -86,8 +140,13 @@ def get_node_malignancies(self):

def set_node_cnvs(self):
# Set root state
n_bins = np.sum(self.outputs['region_sizes'].astype(int))
self.node_dict['0']['cnv'] = np.ones(n_bins,)
n_bins = self.n_bins
if n_bins == -1:
n_bins = np.sum(self.outputs['region_sizes'].astype(int))
else:
if np.sum(self.outputs['region_sizes'].astype(int)) == n_bins - 1:
self.outputs['region_sizes'][-1] += 1
self.node_dict['0']['cnv'] = 2*np.ones(n_bins,)
bin_start = 0
bin_end = 0
for region, state in enumerate(self.outputs['region_neutral_states']):
Expand Down Expand Up @@ -230,12 +289,12 @@ def create_cell_node_ids(self):
self.outputs['cell_node_ids'] = np.zeros((n_cells, 2))
self.outputs['cell_node_ids'][:,0] = np.arange(n_cells)
for node in nodes:
idx = np.where(np.all(self.outputs['inferred_cnvs'] == self.node_dict[node]['cnv'], axis=1))
self.outputs['cell_node_ids'][idx] = node
idx = np.where(np.all(self.outputs['inferred_cnvs'] == self.node_dict[node]['cnv'], axis=1))[0]
self.outputs['cell_node_ids'][idx,1] = node

def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_probs=[0.0,1.0,0.0,1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.01, 0.1, 0.01, 1.0, 0.01],
n_nodes=3, seed=42, postfix="", initial_tree=None, nu=1.0, cluster_sizes=None, region_neutral_states=None, alpha=0., gamma=1., max_scoring=True, copy_number_limit=2,
c_penalise=10.0, lambda_r=0.2, lambda_c=0.1, ploidy=2, verbosity=2, verbose=False, num_labels=False):
c_penalise=10.0, lambda_r=0.2, lambda_c=0.1, ploidy=2, eta=1e-4, verbosity=2, verbose=False, num_labels=False):
if postfix == "":
postfix = self.postfix

Expand Down Expand Up @@ -283,7 +342,7 @@ def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_
f"--move_probs={move_probs_str}", f"--seed={seed}", f"--region_sizes_file={temp_segmented_region_sizes_file}",\
f"--tree_file={temp_tree_file}", f"--nu={nu}", f"--cluster_sizes_file={temp_cluster_sizes_file}", f"--alpha={alpha}",\
f"--max_scoring={max_scoring}", f"--c_penalise={c_penalise}", f"--lambda_r={lambda_r}", f"--gamma={gamma}",\
f"--lambda_c={lambda_c}", f"--region_neutral_states_file={temp_region_neutral_states_file}"]
f"--lambda_c={lambda_c}", f"--eta={eta}", f"--region_neutral_states_file={temp_region_neutral_states_file}"]
if verbose:
print(' '.join(cmd))
cmd_output = subprocess.run(cmd)
Expand All @@ -306,7 +365,7 @@ def learn_tree(self, segmented_data, segmented_region_sizes, n_iters=1000, move_
f"--move_probs={move_probs_str}", f"--seed={seed}", f"--region_sizes_file={temp_segmented_region_sizes_file}",\
f"--nu={nu}", f"--cluster_sizes_file={temp_cluster_sizes_file}", f"--alpha={alpha}", f"--max_scoring={max_scoring}",\
f"--c_penalise={c_penalise}", f"--lambda_r={lambda_r}", f"--lambda_c={lambda_c}", f"--gamma={gamma}",\
f"--region_neutral_states_file={temp_region_neutral_states_file}"]
f"--eta={eta}", f"--region_neutral_states_file={temp_region_neutral_states_file}"]
if verbose:
print(' '.join(cmd))
cmd_output = subprocess.run(cmd)
Expand Down

0 comments on commit 20d0b75

Please sign in to comment.