Skip to content

Commit 861d9d1

Browse files
committed
updated get_layer_data for xarray
1 parent 8229a7b commit 861d9d1

1 file changed

Lines changed: 111 additions & 65 deletions

File tree

mhkit/river/io/d3d.py

Lines changed: 111 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -161,23 +161,23 @@ def _convert_time(
161161
# pylint: disable=too-many-branches
162162
# pylint: disable=too-many-statements
163163
def get_layer_data(
164-
data: netCDF4.Dataset,
164+
data: Union[netCDF4.Dataset, xr.Dataset],
165165
variable: str,
166166
layer_index: int = -1,
167167
time_index: int = -1,
168168
to_pandas: bool = True,
169169
) -> Union[pd.DataFrame, xr.Dataset]:
170170
"""
171-
Get variable data from the NetCDF4 object at a specified layer and timestep.
171+
Get variable data from the NetCDF4 or xarray object at a specified layer and timestep.
172172
If the data is 2D the layer_index is ignored.
173173
174174
Parameters
175175
----------
176-
data: NetCDF4 object
177-
A NetCDF4 object that contains spatial data, e.g. velocity or shear
176+
data: Union[netCDF4.Dataset, xr.Dataset]
177+
A NetCDF4 or xarray Dataset object that contains spatial data, e.g. velocity or shear
178178
stress, generated by running a Delft3D model.
179179
variable: string
180-
Delft3D outputs many vairables. The full list can be
180+
Delft3D outputs many variables. The full list can be
181181
found using "data.variables.keys()" in the console.
182182
layer_index: int
183183
An integer to pull out a layer from the dataset. 0 being closest
@@ -204,26 +204,35 @@ def get_layer_data(
204204
raise TypeError("layer_index must be an int")
205205

206206
if not isinstance(data, (netCDF4.Dataset, xr.Dataset)):
207-
raise TypeError("data must be NetCDF4 object or xarray Dataset")
207+
raise TypeError("data must be a NetCDF4 Dataset or xarray Dataset")
208208

209209
if variable not in data.variables.keys():
210210
raise ValueError("variable not recognized")
211211

212212
if not isinstance(to_pandas, bool):
213213
raise TypeError(f"to_pandas must be of type bool. Got: {type(to_pandas)}")
214214

215-
coords = str(data.variables[variable].coordinates).split()
216-
var = data.variables[variable][:]
217-
max_time_index = data["time"].shape[0] - 1 # to account for zero index
215+
if isinstance(data, netCDF4.Dataset):
216+
coords = str(data.variables[variable].coordinates).split()
217+
var = data.variables[variable][:]
218+
max_time_index = data.variables["time"].shape[0] - 1
219+
elif isinstance(data, xr.Dataset):
220+
coords = list(data[variable].coords)
221+
var = data[variable].values
222+
max_time_index = data["time"].shape[0] - 1
218223

219224
if abs(time_index) > max_time_index:
220225
raise ValueError(
221226
"time_index must be less than the absolute value of the "
222227
f"max time index {max_time_index}"
223228
)
224229

225-
x = np.ma.getdata(data.variables[coords[0]][:], False)
226-
y = np.ma.getdata(data.variables[coords[1]][:], False)
230+
if isinstance(data, netCDF4.Dataset):
231+
x = np.ma.getdata(data.variables[coords[0]][:], False)
232+
y = np.ma.getdata(data.variables[coords[1]][:], False)
233+
elif isinstance(data, xr.Dataset):
234+
x = data[coords[0]].values
235+
y = data[coords[1]].values
227236

228237
if isinstance(var[0][0], np.ma.core.MaskedArray):
229238
max_layer = len(var[0][0])
@@ -233,15 +242,14 @@ def get_layer_data(
233242

234243
v = np.ma.getdata(var[time_index, :, layer_index], False)
235244
dimensions = 3
236-
elif isinstance(var[0][0], xr.core.variable.Variable):
245+
elif isinstance(var[0][0], np.ndarray):
237246
max_layer = var[0][0].shape[0]
238247

239248
if abs(layer_index) > max_layer:
240249
raise ValueError(f"layer_index must be less than the max layer {max_layer}")
241250

242251
v = np.ma.getdata(var[time_index, :, layer_index], False)
243252
dimensions = 3
244-
245253
else:
246254
if not isinstance(var[0][0], np.float64):
247255
raise TypeError("data not recognized")
@@ -250,63 +258,97 @@ def get_layer_data(
250258
v = np.ma.getdata(var[time_index, :], False)
251259

252260
# waterdepth
253-
if "mesh2d" in variable:
254-
cords_to_layers = {
255-
"mesh2d_face_x mesh2d_face_y": {
256-
"name": "mesh2d_nLayers",
257-
"coords": data.variables["mesh2d_layer_sigma"][:],
258-
},
259-
"mesh2d_face_x mesh2d_face_y mesh2d_layer_sigma": {
260-
"name": "mesh2d_nLayers",
261-
"coords": data.variables["mesh2d_layer_sigma"][:],
262-
},
263-
"mesh2d_edge_x mesh2d_edge_y": {
264-
"name": "mesh2d_nInterfaces",
265-
"coords": data.variables["mesh2d_interface_sigma"][:],
266-
},
267-
}
268-
bottom_depth = np.ma.getdata(
269-
data.variables["mesh2d_waterdepth"][time_index, :], False
270-
)
271-
waterlevel = np.ma.getdata(data.variables["mesh2d_s1"][time_index, :], False)
272-
coords = str(data.variables["mesh2d_waterdepth"].coordinates).split()
273-
274-
elif str(data.variables[variable].coordinates) == "FlowElem_xcc FlowElem_ycc":
275-
cords_to_layers = {
276-
"FlowElem_xcc FlowElem_ycc": {
277-
"name": "laydim",
278-
"coords": data.variables["LayCoord_cc"][:],
279-
},
280-
"FlowLink_xu FlowLink_yu": {
281-
"name": "wdim",
282-
"coords": data.variables["LayCoord_w"][:],
283-
},
284-
}
285-
bottom_depth = np.ma.getdata(data.variables["waterdepth"][time_index, :], False)
286-
waterlevel = np.ma.getdata(data.variables["s1"][time_index, :], False)
287-
coords = str(data.variables["waterdepth"].coordinates).split()
288-
else:
289-
cords_to_layers = {
290-
"FlowElem_xcc FlowElem_ycc LayCoord_cc LayCoord_cc": {
291-
"name": "laydim",
292-
"coords": data.variables["LayCoord_cc"][:],
293-
},
294-
"FlowLink_xu FlowLink_yu": {
295-
"name": "wdim",
296-
"coords": data.variables["LayCoord_w"][:],
297-
},
298-
}
299-
bottom_depth = np.ma.getdata(data.variables["waterdepth"][time_index, :], False)
300-
waterlevel = np.ma.getdata(data.variables["s1"][time_index, :], False)
301-
coords = str(data.variables["waterdepth"].coordinates).split()
261+
if isinstance(data, netCDF4.Dataset):
262+
if "mesh2d" in variable:
263+
cords_to_layers = {
264+
"mesh2d_face_x mesh2d_face_y": {
265+
"name": "mesh2d_nLayers",
266+
"coords": data.variables["mesh2d_layer_sigma"][:],
267+
},
268+
"mesh2d_edge_x mesh2d_edge_y": {
269+
"name": "mesh2d_nInterfaces",
270+
"coords": data.variables["mesh2d_interface_sigma"][:],
271+
},
272+
}
273+
bottom_depth = np.ma.getdata(
274+
data.variables["mesh2d_waterdepth"][time_index, :], False
275+
)
276+
waterlevel = np.ma.getdata(data.variables["mesh2d_s1"][time_index, :], False)
277+
coords = str(data.variables["waterdepth"].coordinates).split()
278+
279+
elif str(data.variables[variable].coordinates) == "FlowElem_xcc FlowElem_ycc":
280+
cords_to_layers = {
281+
"FlowElem_xcc FlowElem_ycc": {
282+
"name": "laydim",
283+
"coords": data.variables["LayCoord_cc"][:],
284+
},
285+
"FlowLink_xu FlowLink_yu": {
286+
"name": "wdim",
287+
"coords": data.variables["LayCoord_w"][:],
288+
},
289+
}
290+
bottom_depth = np.ma.getdata(data.variables["waterdepth"][time_index, :], False)
291+
waterlevel = np.ma.getdata(data.variables["s1"][time_index, :], False)
292+
coords = str(data.variables["waterdepth"].coordinates).split()
293+
else:
294+
cords_to_layers = {
295+
"FlowElem_xcc FlowElem_ycc LayCoord_cc LayCoord_cc": {
296+
"name": "laydim",
297+
"coords": data.variables["LayCoord_cc"][:],
298+
},
299+
"FlowLink_xu FlowLink_yu": {
300+
"name": "wdim",
301+
"coords": data.variables["LayCoord_w"][:],
302+
},
303+
}
304+
bottom_depth = np.ma.getdata(data.variables["waterdepth"][time_index, :], False)
305+
waterlevel = np.ma.getdata(data.variables["s1"][time_index, :], False)
306+
coords = str(data.variables["waterdepth"].coordinates).split()
307+
308+
layer_dim = str(data.variables[variable].coordinates)
302309

303-
layer_dim = str(data.variables[variable].coordinates)
310+
elif isinstance(data, xr.Dataset):
311+
if "mesh2d" in variable:
312+
cords_to_layers = {
313+
"mesh2d_face_x mesh2d_face_y": {
314+
"name": "mesh2d_nLayers",
315+
"coords": data.variables["mesh2d_layer_sigma"][:],
316+
},
317+
"mesh2d_edge_x mesh2d_edge_y": {
318+
"name": "mesh2d_nInterfaces",
319+
"coords": data.variables["mesh2d_interface_sigma"][:],
320+
},
321+
}
322+
bottom_depth = data["mesh2d_waterdepth"].values[time_index, :]
323+
waterlevel = data["mesh2d_s1"].values[time_index, :]
324+
coords = list(data["waterdepth"].coords)
325+
elif str(list(data[variable].coords)) == "['FlowElem_xcc', 'FlowElem_ycc', 'time']":
326+
cords_to_layers = {
327+
"FlowElem_xcc FlowElem_ycc": {
328+
"name": "laydim",
329+
"coords": data.variables["LayCoord_cc"][:],
330+
},
331+
"FlowLink_xu FlowLink_yu": {
332+
"name": "wdim",
333+
"coords": data.variables["LayCoord_w"][:],
334+
},
335+
}
336+
bottom_depth = data["waterdepth"].values[time_index, :]
337+
waterlevel = data["s1"].values[time_index, :]
338+
coords = list(data["waterdepth"].coords)
339+
340+
layer_dim = " ".join(map(str, list(data[variable].coords)[0:2]))
304341

305342
try:
306343
cord_sys = cords_to_layers[layer_dim]["coords"]
307344
except KeyError as exc:
308345
raise ValueError("Coordinates not recognized.") from exc
309-
layer_percentages = np.ma.getdata(cord_sys, False) # accumulative
346+
347+
if isinstance(data, netCDF4.Dataset):
348+
layer_percentages = np.ma.getdata(cord_sys, False) # accumulative
349+
elif isinstance(data, xr.Dataset):
350+
layer_percentages= cord_sys.values # accumulative
351+
310352

311353
if layer_dim == "FlowLink_xu FlowLink_yu":
312354
# interpolate
@@ -348,7 +390,11 @@ def get_layer_data(
348390
z = [bottom_depth * layer_percentages[layer_index]]
349391
waterdepth = np.append(waterdepth, z)
350392

351-
time = np.ma.getdata(data.variables["time"][time_index], False) * np.ones(len(x))
393+
if isinstance(data, netCDF4.Dataset):
394+
time = np.ma.getdata(data.variables["time"][time_index], False) * np.ones(len(x))
395+
elif isinstance(data, xr.Dataset):
396+
time= [data.time.values[time_index]] * len(x)
397+
352398

353399
index = np.arange(0, len(time))
354400
layer_data = xr.Dataset(

0 commit comments

Comments
 (0)