Skip to content

Commit

Permalink
working tests ... finally
Browse files Browse the repository at this point in the history
  • Loading branch information
mpound committed Nov 6, 2024
1 parent 1f18f0b commit 862469e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 34 deletions.
14 changes: 8 additions & 6 deletions src/dysh/fits/gbtfitsload.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def __init__(self, fileobj, source=None, hdu=None, skipflags=False, **kwargs):
# https://docs.astropy.org/en/stable/api/astropy.units.cds.enable.html#astropy.units.cds.enable
# cds.enable() # to get mmHg

if kwargs.get("verbose", None):
# ushow/udata depend on the index being present, so check that index is created.
if kwargs.get("verbose", None) and kwargs_opts["index"]:
print("==GBTLoad %s" % fileobj)
self.ushow("OBJECT", 0)
self.ushow("SCAN", 0)
Expand All @@ -118,7 +119,10 @@ def __init__(self, fileobj, source=None, hdu=None, skipflags=False, **kwargs):
lsdf = len(self._sdf)
if lsdf > 1:
print(f"Loaded {lsdf} FITS files")
self.add_history(f"Project ID: {self.projectID}", add_time=True)
if kwargs_opts["index"]:
self.add_history(f"Project ID: {self.projectID}", add_time=True)
else:
print("Reminder: No index created; many functions won't work.")

def __repr__(self):
return str(self.files)
Expand All @@ -141,7 +145,6 @@ def projectID(self):
-------
str
The project ID string
"""
return uniq(self["PROJID"])[0]

Expand Down Expand Up @@ -2392,7 +2395,6 @@ def write(
c = fits.Column(name="FLAGS", format=form, array=flagval)
self._sdf[k]._update_binary_table_column({"FLAGS": c})
ob = self._sdf[k]._bintable_from_rows(rows, b)
print(f"#### {set(ob.data['CTYPE2'])=} ####")
if len(ob.data) > 0:
outhdu.append(ob)
total_rows_written += lr
Expand All @@ -2416,7 +2418,7 @@ def write(
if verbose:
print(f"Total of {total_rows_written} rows written to files.")
else:
hdu = self._sdf[fi[0]]._hdu[fi[0]].copy()
hdu = self._sdf[fi[0]]._hdu[0].copy()
outhdu = fits.HDUList(hdu)
for k in fi:
df = select_from("FITSINDEX", k, _final)
Expand Down Expand Up @@ -2521,7 +2523,7 @@ def __setitem__(self, items, values):
col_exists = len(set(self.columns).intersection(iset)) > 0
# col_in_selection =
if col_exists:
warnings.warn("Changing an existing SDFITS column")
warnings.warn(f"Changing an existing SDFITS column {items}")
# now deal with values as arrays
is_array = False
if isinstance(values, (Sequence, np.ndarray)) and not isinstance(values, str):
Expand Down
51 changes: 28 additions & 23 deletions src/dysh/fits/sdfitsload.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ def udata(self, key, bintable=None):
The unique set of values for the input keyword.
"""
if self._index is None:
raise ValueError("Can't retrieve keyword {key} because no index is present.")
if bintable is not None:
df = self._index[self._index["BINTABLE"] == bintable]
else:
Expand Down Expand Up @@ -658,10 +660,8 @@ def _bintable_from_rows(self, rows=None, bintable=None):
# ensure rows are sorted
rows.sort()
outbintable = self._bintable[bintable].copy()
# print(f"bintable copy data length {len(outbintable.data)}")
outbintable.data = outbintable.data[rows]
# print(f"bintable rows data length {len(outbintable.data)}")
outbintable.update()
outbintable.update_header()
return outbintable

def write(
Expand Down Expand Up @@ -855,7 +855,6 @@ def _add_binary_table_column(self, name, value, bintable=None):
"""
# If we pass the data through a astropy Table first, then the conversion of
# numpy array dtype to FITS format string (e.g, '12A') gets done automatically and correctly.
# print(f"_add_binary_table_column({name}, v={value}, bintable={bintable})")
is_col = isinstance(value, Column)
if is_col:
lenv = len(value.array)
Expand All @@ -869,7 +868,8 @@ def _add_binary_table_column(self, name, value, bintable=None):
if is_col:
self._bintable[bintable].columns.add_col(value)
else:
t = BinTableHDU(Table(names=[name], data=[value]))
t1 = Table(names=[name], data=[value])
t = BinTableHDU(t1)
self._bintable[bintable].columns.add_col(t.columns[name])
# self._update_column_added(bintable, self._bintable[bintable].columns)
# self._bintable[bintable].update_header()
Expand All @@ -888,14 +888,13 @@ def _add_binary_table_column(self, name, value, bintable=None):
)
self._bintable[i].columns.add_col(cut)
else:
t = BinTableHDU(Table(names=[name], data=value[start : start + n]))
t = BinTableHDU(Table(names=[name], data=[value[start : start + n]]))
self._bintable[i].columns.add_col(t.columns[name])
# self._update_column_added(i, self._bintable[i].columns)
# self._bintable[bintable].update_header()
start = start + n

def _update_column_added(self, bintable, coldefs):
print("UPDATE ", coldefs)
self.bintable[bintable].data = fits.fitsrec.FITS_rec.from_columns(
columns=coldefs,
nrows=self.bintable[bintable]._nrows,
Expand All @@ -915,7 +914,6 @@ def _update_binary_table_column(self, column_dict):
# BinTableHDU interface will take care of the types and data lengths matching.
# It will even allow the case where len(values) == 1 to replace all column values with
# a single value.
print(f"updating {column_dict}")
if len(self._bintable) == 1:
for k, v in column_dict.items():
is_col = isinstance(v, Column)
Expand All @@ -926,22 +924,26 @@ def _update_binary_table_column(self, column_dict):
# self._bintable[i].data is an astropy.io.fits.fitsrec.FITS_rec, with length equal
# to number of rows
is_str = isinstance(value, str)
if is_str or not isinstance(value, (Sequence, np.ndarray)):
# print(f"making an array for {k}")
if is_str or not isinstance(value, np.ndarray):
value = np.full(self._bintable[0]._nrows, value)
# NOTE: if k is from the primary header and not a data column
# then this test fails, and we will ADD a new binary data column.
# So the primary header and data column could be inconsistent.
# It actually happens in
# test_gbtfitsload.py:test_set_item for SITELONG=[-42.21]*g.total_rows
# Indeed there is a warning in _add_primary_hdu about this.
# The behavior is intended, the column takes precedence over
# the primary hdu, but we should document this publicly.
if k in self._bintable[0].data.names:
# have to assigned directly to array becaause somewhere
# deep in astropy a ref is kep to the original coldefs data which
# have to assign directly to the array because somewhere
# deep in astropy a ref is kept to the original coldefs data which
# gets recopied if a column is added.
# print(f"Setting {k} {type(value)=}")
self._bintable[0].data.columns[k].array = value
self._bintable[0].data[k] = value
# otherwise we need to add rather than replace/update
else:
# print(f"update calling add for {k}={value}")
self._add_binary_table_column(k, value, 0)
self._bintable[0].update_header()
print(f"#### {set(self._bintable[0].data['CTYPE2'])=} ####")
else:
start = 0
for k, v in column_dict.items():
Expand All @@ -950,11 +952,13 @@ def _update_binary_table_column(self, column_dict):
value = v.array
else:
value = v

is_str = isinstance(value, str)
if not is_str and isinstance(value, (Sequence, np.ndarray)) and len(value) != self.total_rows:
raise ValueError(
f"Length of values array ({len(v)}) for column {k} and total number of rows ({self.total_rows}) aren't equal."
)

# Split values up by length of the individual binary tables
for j in range(len(self._bintable)):
b = self._bintable[j]
Expand All @@ -964,24 +968,25 @@ def _update_binary_table_column(self, column_dict):
b.data.columns[k].array = value
b.data[k] = value
else:
b.data.columns[k].array
b.data.columns[k].array = value[start : start + n]
b.data[k] = value[start : start + n]
start = start + n
else:
v1 = value
if not is_str and isinstance(value, Sequence):
v1 = np.array(value)
else:
v1 = value
n = len(b.data)
if is_str or not isinstance(value, (Sequence, np.ndarray)):
# we have to make an array from value if the user
# did a single assignment for a column, e.g. sdf["TCAL"] = 3.
# Need a new variable here or multiple loops keep expanding value
v1 = np.full(n, value)
else:
v1 = value[start : start + n]
start = start + n
# print(f"BT{j} update calling add for {k}={v1}")
v1 = np.array(value[start : start + n])
start = start + n
self._add_binary_table_column(k, v1, j)
self._bintable[j].update_header()
print(f"#### {set(self._bintable[j].data['CTYPE2'])=} ####")

def __getitem__(self, items):
# items can be a single string or a list of strings.
Expand Down Expand Up @@ -1028,11 +1033,11 @@ def __setitem__(self, items, values):
iset = set(items)
col_exists = len(set(self.columns).intersection(iset)) > 0
if col_exists and "DATA" not in items:
warnings.warn("Changing an existing SDFITS column")
warnings.warn(f"Changing an existing SDFITS column {items}")
try:
self._update_binary_table_column(d)
except Exception as e:
raise Exception(f"Could not update SDFITS binary table because {e}")
raise Exception(f"Could not update SDFITS binary table for {items} because {e}")
# only update the index if the binary table could be updated.
# DATA is not in the index.
if "DATA" not in items:
Expand Down
7 changes: 2 additions & 5 deletions src/dysh/fits/tests/test_gbtfitsload.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def test_gettp(self):
8: {"SCAN": 6, "IFNUM": 2, "PLNUM": 0, "CAL": False, "SIG": True},
}
for k, v in tests.items():
print(f"{k}, {v}")
if v["SIG"] == False:
with pytest.raises(Exception):
tps = sdf.gettp(scan=v["SCAN"], ifnum=v["IFNUM"], plnum=v["PLNUM"], cal=v["CAL"], sig=v["SIG"])
Expand Down Expand Up @@ -480,7 +479,8 @@ def test_write_all(self, tmp_path):
d = tmp_path / "sub"
d.mkdir()
output = d / "test_write_all.fits"
org_sdf.write(output, overwrite=True)
# don't write flags to avoid TDIM84 new column
org_sdf.write(output, overwrite=True, flags=False)
new_sdf = gbtfitsload.GBTFITSLoad(output)
# Compare the index for both SDFITS.
# Note we now auto-add a HISTORY card at instantiation, so drop that
Expand Down Expand Up @@ -622,9 +622,6 @@ def test_azel_coords(self, tmp_path):
sdf_org["RADESYS"] = ""
sdf_org["CTYPE2"] = "AZ"
sdf_org["CTYPE3"] = "EL"
print(f"{sdf_org['RADESYS'][0]=}")
print(f"{sdf_org['CTYPE2'][0]=}")
print(f"{sdf_org['CTYPE3'][0]=}")

# Create a temporary directory and write the modified SDFITS.
new_path = tmp_path / "o"
Expand Down

0 comments on commit 862469e

Please sign in to comment.