Skip to content

Commit

Permalink
Merge branch 'spacetelescope:main' into dark-current-noise-floor
Browse files Browse the repository at this point in the history
  • Loading branch information
mwregan2 authored Feb 5, 2024
2 parents 7c11971 + dfa9a50 commit 6155385
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ jump
- Added more allowable selections for the number of cores to use for
multiprocessing [#183].

- Fixed the computation of the number of rows per slice for multiprocessing,
which caused different results when running the step with multiprocessing [#239].

ramp_fitting
~~~~~~~~~~~~

Expand Down
28 changes: 21 additions & 7 deletions src/stcal/jump/jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def detect_jumps(
err *= gain_2d
readnoise_2d *= gain_2d
# also apply to the after_jump thresholds
after_jump_flag_e1 = after_jump_flag_dn1 * gain_2d
after_jump_flag_e2 = after_jump_flag_dn2 * gain_2d
after_jump_flag_e1 = after_jump_flag_dn1 * np.nanmedian(gain_2d)
after_jump_flag_e2 = after_jump_flag_dn2 * np.nanmedian(gain_2d)

# Apply the 2-point difference method as a first pass
log.info("Executing two-point difference method")
Expand Down Expand Up @@ -279,6 +279,12 @@ def detect_jumps(
minimum_sigclip_groups=minimum_sigclip_groups,
only_use_ints=only_use_ints,
)
# Drop the redundant JUMP_DET bit where the flag value is exactly
# DO_NOT_USE | JUMP_DET or exactly SATURATED | JUMP_DET. The equality test
# leaves every other flag combination untouched.
gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']
# This is the flag that controls the flagging of snowballs.
if expand_large_events:
total_snowballs = flag_large_events(
Expand Down Expand Up @@ -316,7 +322,7 @@ def detect_jumps(
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers
else:
yinc = int(n_rows / n_slices)
yinc = int(n_rows // n_slices)
slices = []
# Slice up data, gdq, readnoise_2d into slices
# Each element of slices is a tuple of
Expand All @@ -325,17 +331,16 @@ def detect_jumps(

# must copy arrays here, find_crs will make copies but if slices
# are being passed in for multiprocessing then the original gdq will be
# modified unless copied beforehand
# modified unless copied beforehand.
gdq = gdq.copy()
data = data.copy()
copy_arrs = False # we don't need to copy arrays again in find_crs

for i in range(n_slices - 1):
slices.insert(
i,
(
data[:, :, i * yinc : (i + 1) * yinc, :],
gdq[:, :, i * yinc : (i + 1) * yinc, :],
gdq[:, :, i * yinc : (i + 1) * yinc, :].copy(),
readnoise_2d[i * yinc : (i + 1) * yinc, :],
rejection_thresh,
three_grp_thresh,
Expand All @@ -361,7 +366,7 @@ def detect_jumps(
n_slices - 1,
(
data[:, :, (n_slices - 1) * yinc : n_rows, :],
gdq[:, :, (n_slices - 1) * yinc : n_rows, :],
gdq[:, :, (n_slices - 1) * yinc : n_rows, :].copy() ,
readnoise_2d[(n_slices - 1) * yinc : n_rows, :],
rejection_thresh,
three_grp_thresh,
Expand All @@ -383,6 +388,8 @@ def detect_jumps(
)
log.info("Creating %d processes for jump detection ", n_slices)
pool = multiprocessing.Pool(processes=n_slices)
######### JUST FOR DEBUGGING #########################
# pool = multiprocessing.Pool(processes=1)
# Starts each slice in its own process. Starmap allows more than one
# parameter to be passed.
real_result = pool.starmap(twopt.find_crs, slices)
Expand Down Expand Up @@ -429,6 +436,13 @@ def detect_jumps(
# save the neighbors to be flagged that will be in the next slice
previous_row_above_gdq = row_above_gdq.copy()
k += 1
# Drop the redundant JUMP_DET bit where the flag value is exactly
# DO_NOT_USE | JUMP_DET or exactly SATURATED | JUMP_DET. The equality test
# leaves every other flag combination untouched.
gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']

# This is the flag that controls the flagging of snowballs.
if expand_large_events:
total_snowballs = flag_large_events(
Expand Down
12 changes: 4 additions & 8 deletions src/stcal/jump/twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,17 @@ def find_crs(
# the transient seen after ramp jumps
flag_e_threshold = [after_jump_flag_e1, after_jump_flag_e2]
flag_groups = [after_jump_flag_n1, after_jump_flag_n2]

for cthres, cgroup in zip(flag_e_threshold, flag_groups):
if cgroup > 0:
if cgroup > 0 and cthres > 0:
cr_intg, cr_group, cr_row, cr_col = np.where(np.bitwise_and(gdq, jump_flag))
for j in range(len(cr_group)):
intg = cr_intg[j]
group = cr_group[j]
row = cr_row[j]
col = cr_col[j]
if e_jump_4d[intg, group - 1, row, col] >= cthres[row, col]:
for kk in range(group, min(group + cgroup + 1, ngroups)):
if (gdq[intg, kk, row, col] & sat_flag) == 0 and (
gdq[intg, kk, row, col] & dnu_flag
) == 0:
gdq[intg, kk, row, col] = np.bitwise_or(gdq[integ, kk, row, col], jump_flag)
if e_jump_4d[intg, group - 1, row, col] >= cthres:
for kk in range(group + 1, min(group + cgroup + 1, ngroups)):
gdq[intg, kk, row, col] = np.bitwise_or(gdq[intg, kk, row, col], jump_flag)
if "stddev" in locals():
return gdq, row_below_gdq, row_above_gdq, num_primary_crs, stddev

Expand Down
94 changes: 94 additions & 0 deletions tests/test_jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
find_faint_extended,
flag_large_events,
point_inside_ellipse,
detect_jumps,
)


Expand All @@ -31,6 +32,99 @@ def _cube(ngroups, readnoise=10):
return _cube


def test_multiprocessing():
    """Check detect_jumps flagging on a small cube for two ``max_cores`` settings.

    A step of 2000 is injected from group 4 onward at pixel (row 5, col 1) and
    the pixel in the next row (row 6, col 1) is pre-flagged DO_NOT_USE from
    group 4 onward.  After detect_jumps, the injected pixel must carry exactly
    JUMP_DET at group 4 and the pre-flagged pixel must still carry exactly
    DO_NOT_USE (i.e. neighbor flagging must not add JUMP_DET on top of it).
    """
    nints = 1
    nrows = 13
    ncols = 2
    ngroups = 13
    readnoise = 10
    frames_per_group = 1

    # Each scenario is (max_cores, gain).  The "5" case will fail without the
    # fixes for PR #239 that prevent double flagging of pixels with jump which
    # already have do_not_use or saturation set.
    # NOTE(review): max_cores="1" presumably takes the single-process path and
    # "5" the sliced multiprocessing path -- confirm against detect_jumps.
    for num_cores, gain in (("1", 4), ("5", 3)):
        # Fresh inputs for every scenario so one run cannot leak flags into the next.
        data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
        readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
        gain_2d = np.ones((nrows, ncols), dtype=np.float32) * gain
        gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
        pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
        err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
        data[0, 4:, 5, 1] = 2000
        gdq[0, 4:, 6, 1] = DQFLAGS['DO_NOT_USE']

        gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
            frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d,
            rejection_thresh=5, three_grp_thresh=6, four_grp_thresh=7,
            max_cores=num_cores, max_jump_to_flag_neighbors=10000,
            min_jump_to_flag_neighbors=100, flag_4_neighbors=True, dqflags=DQFLAGS)

        assert gdq[0, 4, 5, 1] == DQFLAGS['JUMP_DET'], f"max_cores={num_cores}"
        # This value would have been 5 without the fix (multi-core case).
        assert gdq[0, 4, 6, 1] == DQFLAGS['DO_NOT_USE'], f"max_cores={num_cores}"


def test_multiprocessing_big():
    """Check detect_jumps flagging on a 2048-row cube for two ``max_cores`` settings.

    A step of 2000 is injected from group 4 onward at pixel (row 204, col 5)
    and the pixel in the next column (row 204, col 6) is pre-flagged
    DO_NOT_USE from group 4 onward.  After detect_jumps, the injected pixel and
    its neighbor in row 205 must carry exactly JUMP_DET at group 4, while the
    pre-flagged pixel must still carry exactly DO_NOT_USE.
    """
    nints = 1
    nrows = 2048
    ncols = 7
    ngroups = 13
    readnoise = 10
    frames_per_group = 1

    # Each scenario is (max_cores, gain).  The "10" case will fail without the
    # fixes for PR #239 that prevent double flagging of pixels with jump which
    # already have do_not_use or saturation set.
    # NOTE(review): max_cores="1" presumably takes the single-process path and
    # "10" the sliced multiprocessing path -- confirm against detect_jumps.
    for num_cores, gain in (("1", 4), ("10", 3)):
        # Fresh inputs for every scenario so one run cannot leak flags into the next.
        data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
        readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
        gain_2d = np.ones((nrows, ncols), dtype=np.float32) * gain
        gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
        pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
        err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
        data[0, 4:, 204, 5] = 2000
        gdq[0, 4:, 204, 6] = DQFLAGS['DO_NOT_USE']

        gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
            frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d,
            rejection_thresh=5, three_grp_thresh=6, four_grp_thresh=7,
            max_cores=num_cores, max_jump_to_flag_neighbors=10000,
            min_jump_to_flag_neighbors=100, flag_4_neighbors=True, dqflags=DQFLAGS)

        assert gdq[0, 4, 204, 5] == DQFLAGS['JUMP_DET'], f"max_cores={num_cores}"
        assert gdq[0, 4, 205, 5] == DQFLAGS['JUMP_DET'], f"max_cores={num_cores}"
        # This value would have been 5 without the fix (multi-core case).
        assert gdq[0, 4, 204, 6] == DQFLAGS['DO_NOT_USE'], f"max_cores={num_cores}"



def test_find_simple_ellipse():
plane = np.zeros(shape=(5, 5), dtype=np.uint8)
plane[2, 2] = DQFLAGS["JUMP_DET"]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def test_10grps_1cr_afterjump(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0
after_jump_flag_e1 = 1.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -891,7 +891,7 @@ def test_10grps_1cr_afterjump_2group(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0
after_jump_flag_e1 = 1.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -932,7 +932,7 @@ def test_10grps_1cr_afterjump_toosmall(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 10000.0
after_jump_flag_e1 = 10000.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -968,8 +968,8 @@ def test_10grps_1cr_afterjump_twothresholds(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 500.0
after_jump_flag_e2 = np.full(data.shape[2:4], 1.0) * 10.0
after_jump_flag_e1 = 500.0
after_jump_flag_e2 = 10.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down

0 comments on commit 6155385

Please sign in to comment.