Skip to content

Commit e20a06f

Browse files
authored
Lint hindcast (#398)
Adds type hints and enforces linting on the wave hindcast module.
1 parent cde337f commit e20a06f

File tree

4 files changed

+250
-262
lines changed

4 files changed

+250
-262
lines changed

.github/workflows/pylint.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,7 @@ jobs:
4444
- name: Run Pylint on mhkit/river/
4545
run: |
4646
pylint --extension-pkg-allow-list=netCDF4 mhkit/river/
47+
48+
- name: Run Pylint on mhkit/wave/io/hindcast/
49+
run: |
50+
pylint mhkit/wave/io/hindcast/

mhkit/wave/io/hindcast/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""Wave hindcast data import and processing module.
2+
3+
This module provides functionality for importing and processing wave hindcast data,
4+
including wind toolkit data and WPTO hindcast data. The hindcast io module is
5+
separated from the general io module to allow for more efficient handling of
6+
CI tests.
7+
"""
8+
19
from mhkit.wave.io.hindcast import wind_toolkit
210

311
try:
@@ -8,4 +16,3 @@
816
"MHKiT-Python. If you are using Windows and calling from"
917
"MHKiT-MATLAB this is expected."
1018
)
11-
pass

mhkit/wave/io/hindcast/hindcast.py

Lines changed: 116 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,14 @@
55
regions, request point data for various parameters, and request directional
66
spectrum data.
77
8-
Functions:
9-
- region_selection(lat_lon): Returns the name of the predefined region for
10-
given latitude and longitude coordinates.
11-
- request_wpto_point_data(data_type, parameter, lat_lon, years, tree=None,
12-
unscale=True, str_decode=True, hsds=True): Returns data from the WPTO wave
13-
hindcast hosted on AWS at the specified latitude and longitude point(s) for
14-
the requested data type, parameter, and years.
15-
- request_wpto_directional_spectrum(lat_lon, year, tree=None, unscale=True,
16-
str_decode=True, hsds=True): Returns directional spectra data from the WPTO
17-
wave hindcast hosted on AWS at the specified latitude and longitude point(s)
18-
for the given year.
19-
20-
Dependencies:
21-
- sys
22-
- time.sleep
23-
- pandas
24-
- xarray
25-
- numpy
26-
- rex.MultiYearWaveX, rex.WaveX
27-
288
Author: rpauly, aidanbharath, ssolson
299
Date: 2023-09-26
3010
"""
3111

3212
import os
3313
import sys
3414
from time import sleep
15+
from typing import List, Tuple, Union, Optional, Dict
3516
import pandas as pd
3617
import xarray as xr
3718
import numpy as np
@@ -40,7 +21,7 @@
4021
from mhkit.utils.type_handling import convert_to_dataset
4122

4223

43-
def region_selection(lat_lon):
24+
def region_selection(lat_lon: Union[List[float], Tuple[float, float]]) -> str:
4425
"""
4526
Returns the name of the predefined region in which the given
4627
coordinates reside. Can be used to check if the passed lat/lon
@@ -64,13 +45,17 @@ def region_selection(lat_lon):
6445
f"lat_lon values must be of type float or int. Got: {type(lat_lon[0])}"
6546
)
6647

67-
regions = {
48+
regions: Dict[str, Dict[str, List[float]]] = {
6849
"Hawaii": {"lat": [15.0, 27.000002], "lon": [-164.0, -151.0]},
6950
"West_Coast": {"lat": [30.0906, 48.8641], "lon": [-130.072, -116.899]},
7051
"Atlantic": {"lat": [24.382, 44.8247], "lon": [-81.552, -65.721]},
7152
}
7253

73-
def region_search(lat_lon, region, regions):
54+
def region_search(
55+
lat_lon: Union[List[float], Tuple[float, float]],
56+
region: str,
57+
regions: Dict[str, Dict[str, List[float]]],
58+
) -> bool:
7459
return all(
7560
regions[region][dk][0] <= d <= regions[region][dk][1]
7661
for dk, d in {"lat": lat_lon[0], "lon": lat_lon[1]}.items()
@@ -84,18 +69,23 @@ def region_search(lat_lon, region, regions):
8469
return region[0]
8570

8671

72+
# pylint: disable=too-many-arguments
73+
# pylint: disable=too-many-positional-arguments
74+
# pylint: disable=too-many-locals
75+
# pylint: disable=too-many-branches
76+
# pylint: disable=too-many-statements
8777
def request_wpto_point_data(
88-
data_type,
89-
parameter,
90-
lat_lon,
91-
years,
92-
tree=None,
93-
unscale=True,
94-
str_decode=True,
95-
hsds=True,
96-
path=None,
97-
to_pandas=True,
98-
):
78+
data_type: str,
79+
parameter: Union[str, List[str]],
80+
lat_lon: Union[Tuple[float, float], List[Tuple[float, float]]],
81+
years: List[int],
82+
tree: Optional[str] = None,
83+
unscale: bool = True,
84+
str_decode: bool = True,
85+
hsds: bool = True,
86+
path: Optional[str] = None,
87+
to_pandas: bool = True,
88+
) -> Tuple[Union[pd.DataFrame, xr.Dataset], pd.DataFrame]:
9989
"""
10090
Returns data from the WPTO wave hindcast hosted on AWS at the
10191
specified latitude and longitude point(s), or the closest
@@ -190,7 +180,10 @@ def request_wpto_point_data(
190180

191181
# Attempt to load data from cache
192182
# Construct a string representation of the function parameters
193-
hash_params = f"{data_type}_{parameter}_{lat_lon}_{years}_{tree}_{unscale}_{str_decode}_{hsds}_{path}_{to_pandas}"
183+
hash_params = (
184+
f"{data_type}_{parameter}_{lat_lon}_{years}_{tree}_{unscale}_"
185+
f"{str_decode}_{hsds}_{path}_{to_pandas}"
186+
)
194187
cache_dir = _get_cache_dir()
195188
data, meta, _ = handle_caching(
196189
hash_params,
@@ -200,105 +193,105 @@ def request_wpto_point_data(
200193

201194
if data is not None:
202195
return data, meta
203-
else:
204-
if "directional_wave_spectrum" in parameter:
205-
sys.exit("This function does not support directional_wave_spectrum output")
206196

207-
# Check for multiple region selection
208-
if isinstance(lat_lon[0], float):
209-
region = region_selection(lat_lon)
210-
else:
211-
region_list = []
212-
for loc in lat_lon:
213-
region_list.append(region_selection(loc))
214-
if region_list.count(region_list[0]) == len(lat_lon):
215-
region = region_list[0]
216-
else:
217-
sys.exit("Coordinates must be within the same region!")
218-
219-
if path:
220-
wave_path = path
221-
elif data_type == "3-hour":
222-
wave_path = f"/nrel/US_wave/{region}/{region}_wave_*.h5"
223-
elif data_type == "1-hour":
224-
wave_path = (
225-
f"/nrel/US_wave/virtual_buoy/{region}/{region}_virtual_buoy_*.h5"
226-
)
227-
else:
228-
print("ERROR: invalid data_type")
229-
230-
wave_kwargs = {
231-
"tree": tree,
232-
"unscale": unscale,
233-
"str_decode": str_decode,
234-
"hsds": hsds,
235-
"years": years,
236-
}
237-
data_list = []
238-
239-
with MultiYearWaveX(wave_path, **wave_kwargs) as rex_waves:
240-
if isinstance(parameter, list):
241-
for param in parameter:
242-
temp_data = rex_waves.get_lat_lon_df(param, lat_lon)
243-
gid = rex_waves.lat_lon_gid(lat_lon)
244-
cols = temp_data.columns[:]
245-
for i, col in zip(range(len(cols)), cols):
246-
temp = f"{param}_{i}"
247-
temp_data = temp_data.rename(columns={col: temp})
197+
if "directional_wave_spectrum" in parameter:
198+
sys.exit("This function does not support directional_wave_spectrum output")
248199

249-
data_list.append(temp_data)
250-
data = pd.concat(data_list, axis=1)
200+
# Check for multiple region selection
201+
if isinstance(lat_lon[0], float):
202+
region = region_selection(lat_lon)
203+
else:
204+
region_list = []
205+
for loc in lat_lon:
206+
region_list.append(region_selection(loc))
207+
if region_list.count(region_list[0]) == len(lat_lon):
208+
region = region_list[0]
209+
else:
210+
sys.exit("Coordinates must be within the same region!")
251211

252-
else:
253-
data = rex_waves.get_lat_lon_df(parameter, lat_lon)
254-
cols = data.columns[:]
212+
if path:
213+
wave_path = path
214+
elif data_type == "3-hour":
215+
wave_path = f"/nrel/US_wave/{region}/{region}_wave_*.h5"
216+
elif data_type == "1-hour":
217+
wave_path = f"/nrel/US_wave/virtual_buoy/{region}/{region}_virtual_buoy_*.h5"
218+
else:
219+
raise ValueError(
220+
f"Invalid data_type: {data_type}. Must be '3-hour' or '1-hour'"
221+
)
255222

223+
wave_kwargs = {
224+
"tree": tree,
225+
"unscale": unscale,
226+
"str_decode": str_decode,
227+
"hsds": hsds,
228+
"years": years,
229+
}
230+
data_list = []
231+
232+
with MultiYearWaveX(wave_path, **wave_kwargs) as rex_waves:
233+
if isinstance(parameter, list):
234+
for param in parameter:
235+
temp_data = rex_waves.get_lat_lon_df(param, lat_lon)
236+
gid = rex_waves.lat_lon_gid(lat_lon)
237+
cols = temp_data.columns[:]
256238
for i, col in zip(range(len(cols)), cols):
257-
temp = f"{parameter}_{i}"
258-
data = data.rename(columns={col: temp})
239+
temp = f"{param}_{i}"
240+
temp_data = temp_data.rename(columns={col: temp})
259241

260-
meta = rex_waves.meta.loc[cols, :]
261-
meta = meta.reset_index(drop=True)
262-
gid = rex_waves.lat_lon_gid(lat_lon)
263-
meta["gid"] = gid
242+
data_list.append(temp_data)
243+
data = pd.concat(data_list, axis=1)
264244

265-
if not to_pandas:
266-
data = convert_to_dataset(data)
267-
data["time_index"] = pd.to_datetime(data.time_index)
245+
else:
246+
data = rex_waves.get_lat_lon_df(parameter, lat_lon)
247+
cols = data.columns[:]
268248

269-
if isinstance(parameter, list):
270-
param_coords = [f"{param}_{i}" for param in parameter]
271-
data.coords["parameter"] = xr.DataArray(
272-
param_coords, dims="parameter"
273-
)
249+
for i, col in zip(range(len(cols)), cols):
250+
temp = f"{parameter}_{i}"
251+
data = data.rename(columns={col: temp})
274252

275-
data.coords["year"] = xr.DataArray(years, dims="year")
253+
meta = rex_waves.meta.loc[cols, :]
254+
meta = meta.reset_index(drop=True)
255+
gid = rex_waves.lat_lon_gid(lat_lon)
256+
meta["gid"] = gid
276257

277-
meta_ds = meta.to_xarray()
278-
data = xr.merge([data, meta_ds])
258+
if not to_pandas:
259+
data = convert_to_dataset(data)
260+
data["time_index"] = pd.to_datetime(data.time_index)
279261

280-
# Remove the 'index' coordinate
281-
data = data.drop_vars("index")
262+
if isinstance(parameter, list):
263+
param_coords = [f"{param}_{i}" for param in parameter]
264+
data.coords["parameter"] = xr.DataArray(param_coords, dims="parameter")
282265

283-
# save_to_cache(hash_params, data, meta)
284-
handle_caching(
285-
hash_params,
286-
cache_dir,
287-
cache_content={"data": data, "metadata": meta, "write_json": None},
288-
)
266+
data.coords["year"] = xr.DataArray(years, dims="year")
289267

290-
return data, meta
268+
meta_ds = meta.to_xarray()
269+
data = xr.merge([data, meta_ds])
270+
271+
# Remove the 'index' coordinate
272+
data = data.drop_vars("index")
273+
274+
# save_to_cache(hash_params, data, meta)
275+
handle_caching(
276+
hash_params,
277+
cache_dir,
278+
cache_content={"data": data, "metadata": meta, "write_json": None},
279+
)
280+
281+
return data, meta
291282

292283

284+
# pylint: disable=too-many-branches
285+
# pylint: disable=too-many-statements
293286
def request_wpto_directional_spectrum(
294-
lat_lon,
295-
year,
296-
tree=None,
297-
unscale=True,
298-
str_decode=True,
299-
hsds=True,
300-
path=None,
301-
):
287+
lat_lon: Union[Tuple[float, float], List[Tuple[float, float]]],
288+
year: str,
289+
tree: Optional[str] = None,
290+
unscale: bool = True,
291+
str_decode: bool = True,
292+
hsds: bool = True,
293+
path: Optional[str] = None,
294+
) -> Tuple[xr.Dataset, pd.DataFrame]:
302295
"""
303296
Returns directional spectra data from the WPTO wave hindcast hosted
304297
on AWS at the specified latitude and longitude point(s),
@@ -417,10 +410,10 @@ def request_wpto_directional_spectrum(
417410
)
418411

419412
# Create bins for multiple smaller API dataset requests
420-
N = 6
413+
num_bins = 6
421414
length = len(rex_waves)
422-
quotient, remainder = divmod(length, N)
423-
bins = [i * quotient for i in range(N + 1)]
415+
quotient, remainder = divmod(length, num_bins)
416+
bins = [i * quotient for i in range(num_bins + 1)]
424417
bins[-1] += remainder
425418
index_bins = (np.array(bins) * len(frequency) * len(direction)).tolist()
426419

@@ -436,7 +429,7 @@ def request_wpto_directional_spectrum(
436429
try:
437430
data_array = rex_waves[parameter, bins[i] : bins[i + 1], :, :, gid]
438431
str_error = None
439-
except Exception as err:
432+
except OSError as err:
440433
str_error = str(err)
441434

442435
if str_error:
@@ -501,7 +494,7 @@ def request_wpto_directional_spectrum(
501494
return data, meta
502495

503496

504-
def _get_cache_dir():
497+
def _get_cache_dir() -> str:
505498
"""
506499
Returns the path to the cache directory.
507500
"""

0 commit comments

Comments
 (0)