Skip to content

Commit

Permalink
Fixed test failures with dask changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Aug 26, 2024
1 parent 9139b14 commit b8f56f5
Showing 1 changed file with 71 additions and 68 deletions.
139 changes: 71 additions & 68 deletions etspy/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def run_alg(sino, iters, cfg, vol_geom, proj_geom):
return astra.data2d.get(rec_id)


def run_dart(sino, iters, dart_iters, p,
alg_id, proj_id, mask_id, rec_id, sino_id,
thresholds, gray_levels):
def run_dart(sino, iters, dart_iters, p, thresholds, gray_levels, cfg, vol_geom, proj_geom):
"""
Run discrete algebraic reoncsturction technique (DART) algorithm.
Expand All @@ -77,32 +75,32 @@ def run_dart(sino, iters, dart_iters, p,
Number of iterations for the DART reconstruction
p : float
Probability for free pixel determination
alg_id : int
ASTRA algorithm identity
proj_id : int
ASTRA projector identity
mask_id : boolean
ASTRA mask identity
rec_id : boolean
ASTRA reconstruction identity
sino_id : int
ASTRA sinogram identity
thresholds : list or NumPy array
Thresholds for DART reconstruction
gray_levels : list or NumPy array
Gray levels for DART reconstruction
cfg : dict
ASTRA algorithm configuration
vol_geom : dict
ASTRA volume geometry
proj_geom : dict
ASTRA projection geometry
Returns
----------
Numpy array
Reconstruction of input sinogram
"""
thickness, ny = astra.data2d.get(rec_id).shape
astra.data2d.store(sino_id, sino)
astra.data2d.store(rec_id, np.zeros([thickness, ny]))
astra.data2d.store(mask_id, np.ones([thickness, ny]))
proj_id = astra.create_projector("strip", proj_geom, vol_geom)
rec_id = astra.data2d.create("-vol", vol_geom)
sino_id = astra.data2d.create("-sino", proj_geom, sino)
mask_id = astra.data2d.create('-vol', vol_geom, 1)
cfg["ReconstructionDataId"] = rec_id
cfg["ProjectorId"] = proj_id
cfg["ProjectionDataId"] = sino_id
cfg["ReconstructionDataId"] = rec_id
alg_id = astra.algorithm.create(cfg)
astra.algorithm.run(alg_id, iters)
curr_rec = astra.data2d.get(rec_id)
dart_rec = copy.deepcopy(curr_rec)
Expand Down Expand Up @@ -229,10 +227,14 @@ def run(stack, method, niterations=20, constrain=None, thresh=0, cuda=None, thic
cfg["option"]["MinConstraint"] = thresh

elif method.lower() == "dart":
logger.info(
"Reconstructing with CUDA-accelerated DART algorithm (%s iterations)"
% niterations
)
cfg['type'] = 'SART_CUDA'
thresholds = [(gray_levels[i] + gray_levels[i + 1]) // 2 for i in range(len(gray_levels) - 1)]
mask = np.ones([thickness, ny])
mask_id = astra.data2d.create('-vol', vol_geom, mask)
cfg = astra.astra_dict('SART_CUDA')
cfg['option']['MinConstraint'] = 0
cfg['option']['MaxConstraint'] = 255
cfg['option']['ReconstructionMaskId'] = mask_id
Expand All @@ -253,83 +255,84 @@ def run(stack, method, niterations=20, constrain=None, thresh=0, cuda=None, thic
if method.lower() == "dart":
astra.data2d.store(mask_id, np.ones([thickness, ny]))
rec[i, :, :] = run_dart(stack.data[:, :, i], niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
thresholds, gray_levels, cfg, vol_geom, proj_geom)
else:
astra.algorithm.run(alg, niterations)
rec[i, :, :] = astra.data2d.get(rec_id)
else:
if ncores is None:
ncores = min(nx, int(0.9 * mp.cpu_count()))

proj_id = astra.create_projector("strip", proj_geom, vol_geom)

if method.lower() == "fbp":
if method.lower() == 'fbp':
logger.info("Reconstructing with CPU-based FBP algorithm")
cfg = astra.astra_dict("FBP")
cfg["ProjectorId"] = proj_id
cfg["ProjectionDataId"] = sino_id
cfg["ReconstructionDataId"] = rec_id
cfg["option"] = {}
cfg['type'] = 'FBP'
cfg["option"]["FilterType"] = filter.lower()
niterations = 1
elif method.lower() == "sirt":
elif method.lower() == 'sirt':
logger.info("Reconstructing with CPU-based SIRT algorithm")
cfg = astra.astra_dict("SIRT")
cfg["ProjectorId"] = proj_id
cfg["ProjectionDataId"] = sino_id
cfg["ReconstructionDataId"] = rec_id
cfg['type'] = 'SIRT'
if constrain:
cfg["option"] = {}
cfg["option"]["MinConstraint"] = thresh
elif method.lower() == "sart":
elif method.lower() == 'sart':
logger.info("Reconstructing with CPU-based SART algorithm")
cfg = astra.astra_dict("SIRT")
cfg["ProjectorId"] = proj_id
cfg["ProjectionDataId"] = sino_id
cfg["ReconstructionDataId"] = rec_id
cfg['type'] = 'SART'
if constrain:
cfg["option"] = {}
cfg["option"]["MinConstraint"] = thresh
elif method.lower() == "dart":
logger.info("Reconstructing with CPU-based DART algorithm")
cfg['type'] = 'SART'
thresholds = [(gray_levels[i] + gray_levels[i + 1]) // 2 for i in range(len(gray_levels) - 1)]
mask = np.ones([thickness, ny])
mask_id = astra.data2d.create('-vol', vol_geom, mask)
cfg = astra.astra_dict('SART')
cfg["ProjectorId"] = proj_id
cfg['ProjectionDataId'] = sino_id
cfg['ReconstructionDataId'] = rec_id
cfg['option'] = {}
cfg['option']['MinConstraint'] = 0
cfg['option']['MaxConstraint'] = 255
cfg['option']['ReconstructionMaskId'] = mask_id

alg = astra.algorithm.create(cfg)

if method.lower() in ['fbp', 'sirt', 'sart']:
if ncores == 1:
for i in tqdm.tqdm(range(0, nx), disable=not (show_progressbar)):
rec[i] = run_alg(stack.data[:, :, i], niterations, sino_id, alg, rec_id)
tasks = [dask.delayed(run_alg)(stack.data[:, :, i], niterations, cfg,
vol_geom, proj_geom) for i in range(nx)]
if show_progressbar:
with ProgressBar():
results = dask.compute(*tasks, num_workers=ncores)
else:
logger.info("Using %s CPU cores to reconstruct %s slices" % (ncores, nx))
with mp.Pool(ncores) as pool:
for i, result in enumerate(
pool.starmap(run_alg,
[(stack.data[:, :, i], niterations, sino_id, alg, rec_id) for i in range(0, nx)],)):
rec[i] = result
results = dask.compute(*tasks, num_workers=ncores)

for i, result in enumerate(results):
rec[i] = result
# if ncores == 1:
# for i in tqdm.tqdm(range(0, nx), disable=not (show_progressbar)):
# rec[i] = run_alg(stack.data[:, :, i], niterations, sino_id, alg, rec_id)
# else:
# logger.info("Using %s CPU cores to reconstruct %s slices" % (ncores, nx))
# with mp.Pool(ncores) as pool:
# for i, result in enumerate(
# pool.starmap(run_alg,
# [(stack.data[:, :, i], niterations, sino_id, alg, rec_id) for i in range(0, nx)],)):
# rec[i] = result
elif method.lower() == 'dart':
if ncores == 1:
for i in tqdm.tqdm(range(0, nx), disable=not (show_progressbar)):
rec[i] = run_dart(stack.data[:, :, i], niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
tasks = [dask.delayed(run_dart)(stack.data[:, :, i], niterations, dart_iterations, p,
thresholds, gray_levels, cfg, vol_geom, proj_geom) for i in range(nx)]
if show_progressbar:
with ProgressBar():
results = dask.compute(*tasks, num_workers=ncores)
else:
logger.info("Using %s CPU cores to reconstruct %s slices" % (ncores, nx))
with mp.Pool(ncores) as pool:
for i, result in enumerate(
pool.starmap(run_dart,
[(stack.data[:, :, i], niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
for i in range(0, nx)],)):
rec[i] = result
results = dask.compute(*tasks, num_workers=ncores)

for i, result in enumerate(results):
rec[i] = result
# if ncores == 1:
# for i in tqdm.tqdm(range(0, nx), disable=not (show_progressbar)):
# rec[i] = run_dart(stack.data[:, :, i], niterations, dart_iterations, p,
# alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
# else:
# logger.info("Using %s CPU cores to reconstruct %s slices" % (ncores, nx))
# with mp.Pool(ncores) as pool:
# for i, result in enumerate(
# pool.starmap(run_dart,
# [(stack.data[:, :, i], niterations, dart_iterations, p,
# alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
# for i in range(0, nx)],)):
# rec[i] = result
astra.clear()
return rec

Expand Down

0 comments on commit b8f56f5

Please sign in to comment.