Skip to content

Commit

Permalink
Simplified tilt_com alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Oct 12, 2023
1 parent 256057f commit d8f7349
Showing 1 changed file with 57 additions and 83 deletions.
140 changes: 57 additions & 83 deletions tomotools/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import numpy as np
import copy
from scipy import optimize, ndimage
import pylab as plt
import warnings
import tqdm
from pystackreg import StackReg
from scipy.ndimage import center_of_mass
import logging
from numpy.fft import fft, fftshift, ifftshift, ifft
from skimage.registration import phase_cross_correlation as pcc
Expand Down Expand Up @@ -499,7 +496,7 @@ def align_stack(stack, method, start, show_progressbar, nslice, ratio,
return aligned


def tilt_com(stack, locs=None, interactive=False):
def tilt_com(stack, slices=None, nslices=None):
"""
Perform tilt axis alignment using center of mass (CoM) tracking.
Expand All @@ -522,92 +519,69 @@ def tilt_com(stack, locs=None, interactive=False):
def com_motion(theta, r, x0, z0):
return r - x0 * np.cos(theta) - z0 * np.sin(theta)

def get_coms(stack, nslice):
sino = stack.isig[nslice, :].deepcopy().data
coms = [center_of_mass(sino[i, :])[0] for i in range(0, sino.shape[0])]
return np.array(coms)
def get_best_slices(stack, nslices):
total_mass = stack.data.sum((0, 1))
mass_var = stack.data.sum(1).std(0)
mass_var[mass_var == 0] = 1e-5
ratio = (total_mass / mass_var)
locs = ratio.argsort()[::-1][0:nslices]
return locs

def get_coms(stack, slices):
sinos = stack.data[:, :, slices]
y = np.linspace(-int(sinos.shape[1] / 2), int(sinos.shape[1] / 2), sinos.shape[1], dtype='int')
total_mass = sinos.sum(1)
coms = np.sum(np.transpose(sinos, [0, 2, 1]) * y, 2) / total_mass
return coms

def fit_line(x, m, b):
return m * x + b

def shift_stack(stack, shifts):
shifted = stack.deepcopy()
for i in range(0, stack.data.shape[0]):
shifted.data[i, :, :] = ndimage.shift(stack.data[i, :, :],
[shifts[i], 0])
return shifted

def calc_shifts(stack, nslice):
# Convert tilts to rads
thetas = np.pi * stack.metadata.Tomography.tilts / 180.

# Calculate centers of mass for for each row of the sinogram
coms = get_coms(stack, nslice)

# Fit measured centers of mass with function describing
# expected motion of a cylinder
r, x0, z0 = optimize.curve_fit(com_motion, xdata=thetas,
ydata=coms, p0=[0, 0, 0])[0]

# Determine shifts to align centers of mass to cylindrical pathway
shifts = com_motion(thetas, r, x0, z0) - coms
return shifts, coms

def tilt_analyze(stack, slices):
# Convert tilts to rads
thetas = np.pi * stack.metadata.Tomography.tilts / 180.
r = np.zeros(len(slices))
x0 = np.zeros(len(slices))
z0 = np.zeros(len(slices))
for i in range(0, len(slices)):
coms = get_coms(stack, slices[i])
r[i], x0[i], z0[i] = optimize.curve_fit(com_motion,
xdata=thetas,
ydata=coms,
p0=[0, 0, 0])[0]
slope, intercept = optimize.curve_fit(fit_line,
xdata=r,
ydata=slices,
p0=[0, 0])[0]
tilt_shift = stack.data.shape[1] / 2\
- (stack.data.shape[1] / 2 - intercept)\
/ slope
rotation = 180 * np.arctan(1 / slope) / np.pi
return -tilt_shift, -rotation, r

data = stack.deepcopy()
if locs is None:
if interactive:
"""Prompt user for locations at which to fit the CoM"""
warnings.filterwarnings('ignore')
plt.figure(num='Align Tilt', frameon=False)
if len(data.data.shape) == 3:
plt.imshow(data.data[np.int(data.data.shape[0] / 2), :, :],
cmap='gray')
else:
plt.imshow(data, cmap='gray')
plt.title('Choose %s points for tilt axis alignment....' %
str(3))
coords = np.array(plt.ginput(3, timeout=0, show_clicks=True))
plt.close()
locs = np.int16(np.sort(coords[:, 0]))
else:
locs = np.int16(stack.data.shape[2] * np.array([0.33, 0.5, 0.67]))
logger.info("Performing alignments using slices: [%s, %s, %s]"
% (locs[0], locs[1], locs[2]))
else:
locs = np.int16(np.sort(locs))

if stack.metadata.Tomography.tilts is None:
raise ValueError("No tilts in stack.metadata.Tomography.")

shifts, coms = calc_shifts(stack, locs[1])
shifted = shift_stack(stack, shifts)
shifted.metadata.Tomography.shifts[:, 0] = \
shifted.metadata.Tomography.shifts[:, 0] - shifts
tilt_shift, tilt_rotation, r = tilt_analyze(stack, locs)

final = shifted.trans_stack(yshift=tilt_shift, angle=tilt_rotation)
if stack.data.shape[2] < 3:
raise ValueError("Dataset is only %s pixels in x dimension. This method cannot be used.")

if slices is None:
if nslices is None:
nx = stack.data.shape[2]
nslices = 0.1 * nx
if nslices < 3:
nslices = 3
elif nslices > 50:
nslices = 50
else:
if nslices > nx:
raise ValueError("nslices is greater than the X-dimension of the data.")
if nslices > 0.3 * nx:
nslices = int(0.3 * nx)
logger.warning("nslices is greater than 30%% of number of x pixels. Using %s slices instead." % nslices)
slices = get_best_slices(stack, nslices)
logger.info("Performing alignments using best %s slices" % nslices)
else:
slices = np.sort(slices)

coms = get_coms(stack, slices)

thetas = np.pi * stack.metadata.Tomography.tilts / 180.
r = np.zeros(len(slices))
x0 = np.zeros(len(slices))
z0 = np.zeros(len(slices))

for i in range(0, len(slices)):
r[i], x0[i], z0[i] = optimize.curve_fit(com_motion,
xdata=thetas,
ydata=coms[:, i],
p0=[0, 0, 0])[0]
slope, intercept = optimize.curve_fit(fit_line,
xdata=r,
ydata=slices,
p0=[0, 0])[0]
tilt_shift = (stack.data.shape[1] / 2 - intercept) / slope
tilt_rotation = -(180 * np.arctan(1 / slope) / np.pi)

final = stack.trans_stack(yshift=tilt_shift, angle=tilt_rotation)

logger.info("Calculated tilt-axis shift %.2f" % tilt_shift)
logger.info("Calculated tilt-axis rotation %.2f" % tilt_rotation)
Expand Down

0 comments on commit d8f7349

Please sign in to comment.