Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 72 additions & 38 deletions compass/landice/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mpas_tools.scrip.from_mpas import scrip_from_mpas
from netCDF4 import Dataset
from scipy import ndimage
from scipy.interpolate import NearestNDInterpolator, interpn
from scipy.interpolate import interpn
from scipy.ndimage import distance_transform_edt


Expand Down Expand Up @@ -928,8 +928,7 @@ def add_bedmachine_thk_to_ais_gridded_data(self, source_gridded_dataset,
return gridded_dataset_with_bm_thk


def preprocess_ais_data(self, source_gridded_dataset,
floodFillMask):
def preprocess_ais_data(self, source_gridded_dataset, floodFillMask):
"""
Perform adjustments to gridded AIS datasets needed
for rest of compass workflow to utilize them
Expand All @@ -947,13 +946,52 @@ def preprocess_ais_data(self, source_gridded_dataset,
preprocessed_gridded_dataset : str
name of NetCDF file with preprocessed version of gridded dataset
"""

logger = self.logger

def _nearest_fill_from_valid(field2d, valid_mask):
"""
Fill invalid cells in a 2D regular raster using the value from the
nearest valid cell on the same grid.

Parameters
----------
field2d : numpy.ndarray
2D field to be filled
valid_mask : numpy.ndarray
Boolean mask where True marks valid cells

Returns
-------
filled : numpy.ndarray
Copy of field2d with invalid cells filled
"""
valid_mask = np.asarray(valid_mask, dtype=bool)

if field2d.shape != valid_mask.shape:
raise ValueError('field2d and valid_mask must have the same shape')

if not np.any(valid_mask):
raise ValueError('No valid cells available for nearest fill.')

# For EDT, foreground=True cells get mapped to nearest background=False
# cell when return_indices=True. So we pass ~valid_mask.
nearest_inds = distance_transform_edt(
~valid_mask, return_distances=False, return_indices=True
)

filled = np.array(field2d, copy=True)
invalid = ~valid_mask
filled[invalid] = field2d[
nearest_inds[0, invalid],
nearest_inds[1, invalid]
]
return filled

# Apply floodFillMask to thickness field to help with culling
file_with_flood_fill = \
f"{source_gridded_dataset.split('.')[:-1][0]}_floodFillMask.nc"
copyfile(source_gridded_dataset, file_with_flood_fill)

gg = Dataset(file_with_flood_fill, 'r+')
gg.variables['thk'][0, :, :] *= floodFillMask
gg.variables['vx'][0, :, :] *= floodFillMask
Expand All @@ -963,65 +1001,62 @@ def preprocess_ais_data(self, source_gridded_dataset,
# Now deal with the peculiarities of the AIS dataset.
preprocessed_gridded_dataset = \
f"{file_with_flood_fill.split('.')[:-1][0]}_filledFields.nc"
copyfile(file_with_flood_fill,
preprocessed_gridded_dataset)
copyfile(file_with_flood_fill, preprocessed_gridded_dataset)

data = Dataset(preprocessed_gridded_dataset, 'r+')
data.set_auto_mask(False)

x1 = data.variables["x1"][:]
y1 = data.variables["y1"][:]
cellsWithIce = data.variables["thk"][:].ravel() > 0.

thk = data.variables["thk"][0, :, :]
cellsWithIce = thk > 0.0

data.createVariable('iceMask', 'f', ('time', 'y1', 'x1'))
data.variables['iceMask'][:] = data.variables["thk"][:] > 0.
data.variables['iceMask'][:] = data.variables["thk"][:] > 0.0

# Note: dhdt is only reported over grounded ice, so we will have to
# either update the dataset to include ice shelves or give them values of
# 0 with reasonably large uncertainties.
dHdt = data.variables["dhdt"][:]
dHdtErr = 0.05 * dHdt # assign arbitrary uncertainty of 5%
# Where dHdt data are missing, set large uncertainty
dHdtErr[dHdt > 1.e30] = 1.
dHdtErr[dHdt > 1.e30] = 1.0

# Extrapolate fields beyond region with ice to avoid interpolation
# artifacts of undefined values outside the ice domain
# Do this by creating a nearest neighbor interpolator of the valid data
# to recover the actual data within the ice domain and assign nearest
# neighbor values outside the ice domain
xGrid, yGrid = np.meshgrid(x1, y1)
xx = xGrid.ravel()
yy = yGrid.ravel()
# artifacts of undefined values outside the ice domain.
#
# The masks below are masks of valid cells.
bigTic = time.perf_counter()
for field in ['thk', 'bheatflx', 'vx', 'vy',
'ex', 'ey', 'thkerr', 'dhdt']:
tic = time.perf_counter()
logger.info(f"Beginning building interpolator for {field}")
logger.info(f'Beginning nearest-fill preprocessing for {field}')

field2d = data.variables[field][0, :, :]

if field in ['thk', 'thkerr']:
mask = cellsWithIce.ravel()
valid_mask = cellsWithIce
elif field == 'bheatflx':
mask = np.logical_and(
data.variables[field][:].ravel() < 1.0e9,
data.variables[field][:].ravel() != 0.0)
valid_mask = np.logical_and(field2d < 1.0e9, field2d != 0.0)
elif field in ['vx', 'vy', 'ex', 'ey', 'dhdt']:
mask = np.logical_and(
data.variables[field][:].ravel() < 1.0e9,
cellsWithIce.ravel() > 0)
valid_mask = np.logical_and(field2d < 1.0e9, cellsWithIce)
else:
mask = cellsWithIce
interp = NearestNDInterpolator(
list(zip(xx[mask], yy[mask])),
data.variables[field][:].ravel()[mask])
toc = time.perf_counter()
logger.info(f"Finished building interpolator in {toc - tic} seconds")
valid_mask = cellsWithIce

logger.info(f'{field}: {valid_mask.sum()} valid cells, '
f'{(~valid_mask).sum()} cells to fill')

filled2d = _nearest_fill_from_valid(field2d, valid_mask)
data.variables[field][0, :, :] = filled2d

tic = time.perf_counter()
logger.info(f"Beginning interpolation for {field}")
# NOTE: Do not need to evaluate the extrapolator at all grid cells.
# Only needed for ice-free grid cells, since is NN extrapolation
data.variables[field][0, :] = interp(xGrid, yGrid)
toc = time.perf_counter()
logger.info(f"Interpolation completed in {toc - tic} seconds")
logger.info(f'Nearest-fill preprocessing for {field} completed in '
f'{toc - tic:.3f} seconds')

bigToc = time.perf_counter()
logger.info(f"All interpolations completed in {bigToc - bigTic} seconds.")
logger.info(f'All nearest-fill preprocessing completed in '
f'{bigToc - bigTic:.3f} seconds.')

# Now perform some additional clean up adjustments to the dataset
data.createVariable('dHdtErr', 'f', ('time', 'y1', 'x1'))
Expand All @@ -1036,7 +1071,6 @@ def preprocess_ais_data(self, source_gridded_dataset,

data.variables['subm'][:] *= -1.0 # correct basal melting sign
data.variables['subm_ss'][:] *= -1.0

data.renameVariable('dhdt', 'dHdt')
data.renameVariable('thkerr', 'topgerr')

Expand Down
6 changes: 3 additions & 3 deletions compass/landice/tests/antarctica/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(self, test_case):
self.add_output_file(
filename=f'{self.mesh_filename[:-3]}_ismip6_regionMasks.nc')
self.add_input_file(
filename='antarctica_1km_2024_01_29.nc',
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to switch to a coarser dataset? Is that accounting for any of the speedup you reported?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was intentional, and no I don't think it accounts for the speedup I reported. This was inadvertently switched from 8km to 1km in this recent commit: 0b52238#diff-992c4799a18d79fe0e3d55859f4b5f326104e5af5e6b8a5b0459848aa90a84ca (and was missed because there were far too many changes in that PR). This just resets to the original 8km data set, but I have been careful to be consistent in which resolution I use between main and this branch.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my most recent testing, in which I definitely used the 8km source dataset, I get All nearest-fill preprocessing completed in 0.288 seconds for this branch and All interpolations completed in 11.316709512029774 seconds. on main. So still a ~2 order of magnitude speedup when using a coarse data set, but not so impactful when just cutting off 11 seconds instead of 4100 seconds :).

target='antarctica_1km_2024_01_29.nc',
filename='antarctica_8km_2024_01_29.nc',
target='antarctica_8km_2024_01_29.nc',
database='')

# no setup() method is needed
Expand All @@ -70,7 +70,7 @@ def run(self):
section_name = 'mesh'

# TODO: do we want to add this to the config file?
source_gridded_dataset = 'antarctica_1km_2024_01_29.nc'
source_gridded_dataset = 'antarctica_8km_2024_01_29.nc'

if bedmachine_dataset is not None:
bm_updated_gridded_dataset = (
Expand Down
Loading