diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 5955e99c..5e53787c 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -1,7 +1,11 @@ # SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# SPDX-FileContributor: Bernhard Haas +# SPDX-FileContributor: Sahil Jhawar # # SPDX-License-Identifier: Apache-2.0 +"""Combined RBM Dataset class supporting .mat, .pickle, and .nc file formats.""" + from __future__ import annotations import datetime as dt @@ -11,6 +15,7 @@ from typing import Any, Literal import distance +import netCDF4 import numpy as np from dateutil.relativedelta import relativedelta from numpy.typing import NDArray @@ -32,6 +37,7 @@ VariableEnum, VariableLiteral, ) +from swvo.io.RBMDataSet.custom_enums import MfmEnumLiteral from swvo.io.RBMDataSet.utils import ( get_file_path_any_format, join_var, @@ -41,10 +47,47 @@ from swvo.io.utils import enforce_utc_timezone +def _read_all_datasets_netcdf(file_path: str | Path) -> dict[str, Any]: + """Reads all datasets (variables) from a NetCDF file, including those in groups. + + This function recursively traverses all groups and variables in a NetCDF-4 + file and stores their data in a dictionary. The key for each dataset is its + full hierarchical path. + + Args: + file_path (str | Path): The path to the NetCDF file. + + Returns: + Dict[str, Any]: A dictionary where keys are the full variable paths + and values are the corresponding NumPy arrays. + """ + datasets: dict[str, Any] = {} + file_path = Path(file_path) + + def _read_all_recursively(group: netCDF4.Group | netCDF4.Dataset, path: str = ""): + for var_name, var_obj in group.variables.items(): + full_path = f"{path}/{var_name}" if path else var_name + datasets[full_path] = var_obj[:] + + for group_name, group_obj in group.groups.items(): + new_path = f"{path}/{group_name}" if path else group_name + _read_all_recursively(group_obj, new_path) + + if not file_path.exists(): + print(f"File not found: {file_path}") + return {} + + with netCDF4.Dataset(file_path, "r") as nc_file: + _read_all_recursively(nc_file) + + return datasets + + class RBMDataSet: - """RBMDataSet class for loading and managing data. + """RBMDataSet class supporting .mat, .pickle, and .nc file formats. - This class can load data either from files or from a dictionary. + This unified class handles loading RBM (Radiation Belt Model) data from multiple + file formats. It can load data either from files or from a dictionary. For file-based loading, provide `start_time`, `end_time`, and `folder_path`. For dictionary-based loading, initialize without these parameters and use `update_from_dict()`. @@ -63,7 +106,7 @@ class RBMDataSet: End time for file-based loading. folder_path : Path, optional Base folder path for file-based loading. - preferred_extension : Literal["mat", "pickle"], optional + preferred_extension : Literal["mat", "pickle", "nc"], optional Preferred file extension for file-based loading. Default is "pickle". verbose : bool, optional Whether to print verbose output. Default is True. @@ -128,12 +171,13 @@ def __init__( start_time: dt.datetime | None = None, end_time: dt.datetime | None = None, folder_path: Path | None = None, - preferred_extension: Literal["mat", "pickle"] = "pickle", + preferred_extension: Literal["mat", "pickle", "nc"] = "pickle", *, verbose: bool = True, enable_dict_loading: bool = False, ) -> None: self.possible_variables: list[str] = list(VariableLiteral.__args__) + # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): if satellite.lower() == "goesprimary": @@ -149,11 +193,17 @@ def __init__( if isinstance(mfm, str): mfm = MfmEnum[mfm.upper()] + # Validate preferred_extension + if preferred_extension not in ("mat", "pickle", "nc"): + msg = f"preferred_extension must be 'mat', 'pickle', or 'nc', got '{preferred_extension}'" + raise ValueError(msg) + # Store the original satellite enum for properties and other attributes self._satellite = satellite self._instrument = instrument self._mfm = mfm self._verbose = verbose + self._preferred_ext = preferred_extension # For dict-based loading, modify satellite properties if start_time is None and end_time is None and folder_path is None: @@ -171,7 +221,6 @@ def __init__( self._end_time = end_time self._satellite = satellite self._folder_path = Path(folder_path) - self._preferred_ext = preferred_extension self._folder_type = self._satellite.folder_type self._file_path_stem = self._create_file_path_stem() self._file_name_stem = self._create_file_name_stem() @@ -179,6 +228,7 @@ def __init__( self._date_of_files = self._create_date_list() self._file_loading_mode = True self._enable_dict_loading = enable_dict_loading + self._netcdf_dataset_cache: dict[Path, dict[str, Any]] = {} def __repr__(self) -> str: return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})" @@ -197,14 +247,18 @@ def __getattr__(self, name: str) -> NDArray[np.float64]: # Handle computed properties for both modes if name == "P": if len(self.MLT) == 0: # MLT not found - return np.asarray([]) - return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + self.P = np.asarray([]) + else: + self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + return self.P if name == "InvV": if len(self.InvK) == 0 or len(self.InvMu) == 0: # invariants not found - return np.asarray([]) - inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) - return self.InvMu * (inv_K_repeated + 0.5) ** 2 + self.InvV = np.asarray([]) + else: + inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) + self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 + return self.InvV # check if a sat variable is requested # if we find a similar word, suggest that to the user @@ -289,7 +343,7 @@ def update_from_dict( VariableNotFoundError If a key in the `source_dict` is not a valid `VariableLiteral`. RuntimeError - If the `RBMDataSet` is in file loading mode and dictionary loading is not enabled. + If the RBMDataSet is in file loading mode and dictionary loading is not enabled. """ if self._file_loading_mode and not self._enable_dict_loading: @@ -328,11 +382,14 @@ def _create_date_list(self) -> list[dt.datetime]: return list(date_of_files) def _create_file_path_stem(self) -> Path: - # implement special cases here - # if self._satellite == SatelliteEnum.THEMIS: - # pass + """Create the file path stem based on format and folder type.""" if self._folder_type == FolderTypeEnum.DataServer: - return self._folder_path / self._satellite.mission / self._satellite.sat_name / "Processed_Mat_Files" + if self._preferred_ext == "nc": + # NetCDF files use a different path structure + return self._folder_path / self._satellite.mission / self._satellite.sat_name + else: + # .mat and .pickle files use Processed_Mat_Files subdirectory + return self._folder_path / self._satellite.mission / self._satellite.sat_name / "Processed_Mat_Files" if self._folder_type == FolderTypeEnum.SingleFolder: return self._folder_path @@ -341,10 +398,7 @@ def _create_file_path_stem(self) -> Path: raise ValueError(msg) def _create_file_name_stem(self) -> str: - # implement special cases here - # if self._satellite == SatelliteEnum.THEMIS: - # pass - + """Create the file name stem.""" return self._satellite.sat_name + "_" + self._instrument.instrument_name + "_" def get_satellite_name(self) -> str: @@ -370,13 +424,20 @@ def get_print_name(self) -> str: return self._satellite.sat_name + " " + self._instrument.instrument_name def _load_variable(self, var: Variable | VariableEnum) -> None: + """Load variable using format-specific loading logic.""" + if self._preferred_ext == "nc": + self._load_variable_netcdf(var) + else: + self._load_variable_mat_pickle(var) + + def _load_variable_mat_pickle(self, var: Variable | VariableEnum) -> None: + """Load variable from .mat or .pickle files.""" loaded_var_arrs: dict[str, NDArray[np.number]] = {} var_names_storred: list[str] = [] # computed values if isinstance(var, VariableEnum) and var == VariableEnum.INV_V: inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) - self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 return @@ -421,10 +482,6 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time) for key in file_content: - # if key == 'time' and var not in [VariableEnum.Time, VariableEnum.DateTime]: - # only save time if directly requested - # continue - var_arr = file_content[key] if ((not isinstance(var_arr, np.ndarray)) or (not np.issubdtype(var_arr.dtype, np.number))) and ( key != "datetime" @@ -456,11 +513,144 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: setattr(self, var_name, loaded_var_arrs[var_name]) + def _load_variable_netcdf(self, var: Variable | VariableEnum) -> None: + """Load variable from NetCDF files.""" + loaded_var_arrs: dict[str, NDArray[np.number]] = {} + var_names_stored: list[str] = [] + + # computed values + if isinstance(var, VariableEnum) and var == VariableEnum.INV_V: + inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) + self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 + return + + if isinstance(var, VariableEnum) and var == VariableEnum.P: + self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + return + + for date in self._date_of_files: + if self._folder_type == FolderTypeEnum.DataServer: + start_month = date.replace(day=1) + next_month = start_month + relativedelta(months=1, days=-1) + date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") + + file_name = self._file_name_stem + date_str + "_" + self._mfm.mfm_name + ".nc" + else: + raise NotImplementedError + + file_path = self._file_path_stem / file_name + datasets = self._get_cached_datasets_netcdf(file_path) + + if datasets == {}: + continue + + # also store python datetimes for binning + datetimes = typing.cast( + NDArray[np.object_], + np.asarray( + [dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in datasets["time"]] + ), + ) + datasets["datetime"] = datetimes + + # limit in time + correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time) + + for key, var_arr in datasets.items(): + if ((not isinstance(var_arr, np.ndarray)) or (not np.issubdtype(var_arr.dtype, np.number))) and ( + key != "datetime" + ): + # var represents some strings or metadata objects; don't read them + continue + var_arr = typing.cast("NDArray[np.number]", var_arr) + + # check if var is time dependent + if var_arr.shape[0] == correct_time_idx.shape[0]: + var_arr_trimmed = var_arr[correct_time_idx.reshape(-1), ...] + + joined_value = ( + join_var(loaded_var_arrs[key], var_arr_trimmed) if key in loaded_var_arrs else var_arr_trimmed + ) + else: + joined_value = var_arr + + loaded_var_arrs[key] = joined_value # ty:ignore[invalid-assignment] + + if key not in var_names_stored: + var_names_stored.append(key) + + # not a single file was found + if var.var_name not in var_names_stored: + setattr(self, var.var_name, np.asarray([])) + + for var_name in var_names_stored: + if var_name == "datetime": + loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # ty:ignore[invalid-assignment] + + rbm_var_names = self._get_rbm_name_for_nc(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type] + + if rbm_var_names is not None: + if isinstance(rbm_var_names, list): + for name in rbm_var_names: + setattr(self, name, loaded_var_arrs[var_name]) + else: + setattr(self, rbm_var_names, loaded_var_arrs[var_name]) + + def _get_cached_datasets_netcdf(self, file_path: Path) -> dict[str, Any]: + """Return cached parsed NetCDF content for a monthly file.""" + file_path = Path(file_path) + if file_path not in self._netcdf_dataset_cache: + self._netcdf_dataset_cache[file_path] = _read_all_datasets_netcdf(file_path) + return self._netcdf_dataset_cache[file_path] + + @classmethod + def _get_rbm_name_for_nc( + cls, var_name: str, mag_field: MfmEnumLiteral + ) -> VariableLiteral | None | list[VariableLiteral]: + """Map NetCDF variable names to RBM variable names.""" + match var_name: + case "time": + return "time" + case "datetime": + return "datetime" + case "flux/FEDU": + return ["Flux", "FEDU"] + case "flux/FEIU": + return ["Flux", "FEIU"] + case "flux/alpha_eq": + return "alpha_eq_model" + case "flux/energy": + return "energy_channels" + case "flux/alpha_local": + return "alpha_local" + case "position/xGEO": + return "xGEO" + case _ if var_name == f"position/{mag_field}/MLT": + return "MLT" + case _ if var_name == f"position/{mag_field}/R0": + return "R0" + case _ if var_name == f"position/{mag_field}/Lstar": + return "Lstar" + case _ if var_name == f"position/{mag_field}/Lm": + return "Lm" + case _ if var_name == f"mag_field/{mag_field}/B_local": + return "B_total" + case "psd/PSD": + return "PSD" + case _ if var_name == f"psd/{mag_field}/inv_mu": + return "InvMu" + case _ if var_name == f"psd/{mag_field}/inv_K": + return "InvK" + case "density/density_local": + return "density" + case _: + return None + def get_loaded_variables(self) -> list[str]: """Get a list of currently loaded variable names.""" loaded_vars = [] for var in VariableEnum: - if hasattr(self, var.var_name): + if var.var_name in self.__dict__: loaded_vars.append(var.var_name) return loaded_vars diff --git a/swvo/io/RBMDataSet/RBMDataSetManager.py b/swvo/io/RBMDataSet/RBMDataSetManager.py deleted file mode 100644 index cdb1301f..00000000 --- a/swvo/io/RBMDataSet/RBMDataSetManager.py +++ /dev/null @@ -1,169 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from datetime import datetime -from pathlib import Path -from typing import Iterable, Literal, overload - -from swvo.io.RBMDataSet.custom_enums import ( - FolderTypeEnum, - InstrumentEnum, - MfmEnum, - Satellite, - SatelliteEnum, - SatelliteLike, -) -from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet - -RBMDataSetHash = tuple[ - datetime, - datetime, - Path, - Satellite | SatelliteEnum, - InstrumentEnum, - MfmEnum, - FolderTypeEnum, -] - - -class RBMDataSetManager: - """ - RBMDataSetManager class for managing RBMDataSet instances. - - Notes - ----- - Use the `load` class method to create and retrieve datasets. Direct instantiation is not allowed. - - Raises - ------ - RuntimeError - If the constructor is called directly instead of using the `load` method. - """ - - _instance = None - data_set_dict: dict[RBMDataSetHash, RBMDataSet] - - def __init__(self) -> None: - msg = "Call load() instead!" - raise RuntimeError(msg) - - @overload - @classmethod - def load( - cls, - start_time: datetime, - end_time: datetime, - folder_path: Path, - satellite: SatelliteLike, - instrument: InstrumentEnum, - mfm: MfmEnum, - folder_type: FolderTypeEnum = FolderTypeEnum.DataServer, - *, - verbose: bool = True, - preferred_extension: str = "pickle", - ) -> RBMDataSet: ... - - @overload - @classmethod - def load( - cls, - start_time: datetime, - end_time: datetime, - folder_path: Path, - satellite: Iterable[SatelliteLike], - instrument: InstrumentEnum, - mfm: MfmEnum, - folder_type: FolderTypeEnum = FolderTypeEnum.DataServer, - *, - verbose: bool = True, - preferred_extension: str = "pickle", - ) -> list[RBMDataSet]: ... - - @classmethod - def load( - cls, - start_time: datetime, - end_time: datetime, - folder_path: Path, - satellite: SatelliteLike | Iterable[SatelliteLike], - instrument: InstrumentEnum, - mfm: MfmEnum, - folder_type: FolderTypeEnum = FolderTypeEnum.DataServer, - *, - verbose: bool = True, - preferred_extension: Literal["mat", "pickle"] = "pickle", - ) -> RBMDataSet | list[RBMDataSet]: - """Loads an RBMDataSet or a list of RBMDataSets based on the provided parameters. - - Parameters - ---------- - start_time : datetime - Start time of the data set. - end_time : datetime - End time of the data set. - folder_path : Path - Path to the folder where the data set is stored. - satellite : :class:`SatelliteLike` | Iterable[:class:`SatelliteLike`] - Satellite identifier(s) as enum or string. If a single satellite is provided, it can be a string or an enum. - instrument : :class:`InstrumentEnum` - Instrument enumeration, e.g., :class:`InstrumentEnum.HOPE`. - mfm : :class:`MfmEnum` - Magnetic field model enum, e.g., :class:`MfmEnum.T89`. - folder_type : :class:`FolderTypeEnum`, optional - Type of folder where the data is stored, by default :class:`FolderTypeEnum.DataServer`. - verbose : bool, optional - Whether to print verbose output, by default True. - preferred_extension : str, optional - Preferred file extension for the data set to be loaded, by default "pickle". - - Returns - ------- - Union[:class:`RBMDataSet`, list[:class:`RBMDataSet`]] - An instance of RBMDataSet or a list of RBMDataSet instances, depending on the input parameters. - Variables are lazily loaded from the file system when accessed. - """ - if cls._instance is None: - print("Initiating new RBMDataSetManager!") - cls._instance = cls.__new__(cls) - cls._instance.data_set_dict = {} - - if isinstance(satellite, str): - satellite = SatelliteEnum[satellite] - - if not isinstance(satellite, Iterable): - satellite = (satellite,) - - return_list: list[RBMDataSet] | RBMDataSet = [] - for sat in satellite: - key_tuple = ( - start_time, - end_time, - folder_path, - sat, - instrument, - mfm, - folder_type, - ) - - if key_tuple in cls._instance.data_set_dict: - return_list.append(cls._instance.data_set_dict[key_tuple]) - else: - cls._instance.data_set_dict[key_tuple] = RBMDataSet( - satellite=sat, # ty:ignore[invalid-argument-type] - instrument=instrument, - mfm=mfm, - start_time=start_time, - end_time=end_time, - folder_path=folder_path, - verbose=verbose, - preferred_extension=preferred_extension, - ) - return_list.append(cls._instance.data_set_dict[key_tuple]) - - if len(return_list) == 1: - return_list = return_list[0] - - return return_list diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py deleted file mode 100644 index 156d5e71..00000000 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ /dev/null @@ -1,268 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -import datetime as dt -import typing -from pathlib import Path -from typing import Any - -import netCDF4 -import numpy as np -from dateutil.relativedelta import relativedelta -from numpy.typing import NDArray - -from swvo.io.RBMDataSet import ( - RBMDataSet, -) -from swvo.io.RBMDataSet.custom_enums import ( - FolderTypeEnum, - InstrumentLike, - MfmEnumLiteral, - MfmLike, - SatelliteLike, - Variable, - VariableEnum, - VariableLiteral, -) -from swvo.io.RBMDataSet.utils import join_var - - -def _read_all_datasets_netcdf(file_path: str | Path) -> dict[str, Any]: - """Reads all datasets (variables) from a NetCDF file, including those in groups. - - This function recursively traverses all groups and variables in a NetCDF-4 - file and stores their data in a dictionary. The key for each dataset is its - full hierarchical path. - - Args: - file_path (str | Path): The path to the NetCDF file. - - Returns: - Dict[str, Any]: A dictionary where keys are the full variable paths - and values are the corresponding NumPy arrays. - """ - datasets: dict[str, Any] = {} - file_path = Path(file_path) - - def _read_all_recursively(group: netCDF4.Group | netCDF4.Dataset, path: str = ""): - for var_name, var_obj in group.variables.items(): - full_path = f"{path}/{var_name}" if path else var_name - datasets[full_path] = var_obj[:] - - for group_name, group_obj in group.groups.items(): - new_path = f"{path}/{group_name}" if path else group_name - _read_all_recursively(group_obj, new_path) - - if not file_path.exists(): - print(f"File not found: {file_path}") - return {} - - with netCDF4.Dataset(file_path, "r") as nc_file: - _read_all_recursively(nc_file) - - return datasets - - -class RBMNcDataSet(RBMDataSet): - """Class for handling RBM NetCDF data files.""" - - datetime: list[dt.datetime] - time: NDArray[np.float64] - energy_channels: NDArray[np.float64] - alpha_local: NDArray[np.float64] - alpha_eq_model: NDArray[np.float64] - alpha_eq_real: NDArray[np.float64] - InvMu: NDArray[np.float64] - InvMu_real: NDArray[np.float64] - InvK: NDArray[np.float64] - InvV: NDArray[np.float64] - Lstar: NDArray[np.float64] - Flux: NDArray[np.float64] - PSD: NDArray[np.float64] - MLT: NDArray[np.float64] - B_SM: NDArray[np.float64] - B_total: NDArray[np.float64] - B_sat: NDArray[np.float64] - xGEO: NDArray[np.float64] - P: NDArray[np.float64] - R0: NDArray[np.float64] - density: NDArray[np.float64] - - def __init__( - self, - start_time: dt.datetime, - end_time: dt.datetime, - folder_path: Path, - satellite: SatelliteLike, - instrument: InstrumentLike, - mfm: MfmLike, - *, - verbose: bool = True, - ) -> None: - super().__init__( - satellite=satellite, - instrument=instrument, - mfm=mfm, - start_time=start_time, - end_time=end_time, - folder_path=folder_path, - verbose=verbose, - ) - - mfm_str = mfm if isinstance(mfm, str) else mfm.mfm_name - - self.variable_lut = { - "time": "time", - "datetime": "datetime", - "flux/FEDU": "Flux", - "flux/alpha_eq": "alpha_eq_model", - "flux/energy": "energy_channels", - "flux/alpha_local": "alpha_local", - "position/xGEO": "xGEO", - "psd/PSD": "PSD", - "density/density_local": "density", - - f"position/{mfm_str}/MLT": "MLT", - f"position/{mfm_str}/R0": "R0", - f"position/{mfm_str}/Lstar": "Lstar", - f"position/{mfm_str}/Lm": "Lm", - f"mag_field/{mfm_str}/B_local": "B_total", - f"psd/{mfm_str}/inv_mu": "InvMu", - f"psd/{mfm_str}/inv_K": "InvK", - } - - def _create_file_path_stem(self) -> Path: - # implement special cases here - # if self._satellite == SatelliteEnum.THEMIS: - # pass - if self._folder_type == FolderTypeEnum.DataServer: - return self._folder_path / self._satellite.mission / self._satellite.sat_name - - if self._folder_type == FolderTypeEnum.SingleFolder: - return self._folder_path - - msg = "Encountered invalid FolderTypeEnum!" - raise ValueError(msg) - - def _load_variable(self, var: Variable | VariableEnum) -> None: - loaded_var_arrs: dict[str, NDArray[np.number]] = {} - var_names_stored: list[str] = [] - - # computed values - if isinstance(var, VariableEnum) and var == VariableEnum.INV_V: - inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) - - self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 - return - - if isinstance(var, VariableEnum) and var == VariableEnum.P: - self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) - return - - for date in self._date_of_files: - if self._folder_type == FolderTypeEnum.DataServer: - start_month = date.replace(day=1) - next_month = start_month + relativedelta(months=1, days=-1) - date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") - - file_name = self._file_name_stem + date_str + "_" + self._mfm.mfm_name + ".nc" - else: - raise NotImplementedError - - datasets = _read_all_datasets_netcdf(self._file_path_stem / file_name) - - if datasets == {}: - continue - - # also store python datetimes for binning - datetimes = typing.cast( - NDArray[np.object_], - np.asarray( - [dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in datasets["time"]] - ), - ) - datasets["datetime"] = datetimes - - # limit in time - correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time) - - for key, var_arr in datasets.items(): - if ((not isinstance(var_arr, np.ndarray)) or (not np.issubdtype(var_arr.dtype, np.number))) and ( - key != "datetime" - ): - # var represents some strings or metadata objects; don't read them - continue - var_arr = typing.cast("NDArray[np.number]", var_arr) - - # check if var is time dependent - if var_arr.shape[0] == correct_time_idx.shape[0]: - var_arr_trimmed = var_arr[correct_time_idx.reshape(-1), ...] - - joined_value = ( - join_var(loaded_var_arrs[key], var_arr_trimmed) if key in loaded_var_arrs else var_arr_trimmed - ) - else: - joined_value = var_arr - - loaded_var_arrs[key] = joined_value # ty:ignore[invalid-assignment] - - if key not in var_names_stored: - var_names_stored.append(key) - - # not a single file was found - if var.var_name not in var_names_stored: - setattr(self, var.var_name, np.asarray([])) - - for var_name in var_names_stored: - if var_name == "datetime": - loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # ty:ignore[invalid-assignment] - - rbm_var_names = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type] - - if rbm_var_names is not None: - if isinstance(rbm_var_names, list): - for name in rbm_var_names: - setattr(self, name, loaded_var_arrs[var_name]) - else: - setattr(self, rbm_var_names, loaded_var_arrs[var_name]) - - @classmethod - def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None | list[VariableLiteral]: - match var_name: - case "time": - return "time" - case "datetime": - return "datetime" - case "flux/FEDU": - return ["Flux", "FEDU"] - case "flux/FEIU": - return ["Flux", "FEIU"] - case "flux/alpha_eq": - return "alpha_eq_model" - case "flux/energy": - return "energy_channels" - case "flux/alpha_local": - return "alpha_local" - case "position/xGEO": - return "xGEO" - case _ if var_name == f"position/{mag_field}/MLT": - return "MLT" - case _ if var_name == f"position/{mag_field}/R0": - return "R0" - case _ if var_name == f"position/{mag_field}/Lstar": - return "Lstar" - case _ if var_name == f"position/{mag_field}/Lm": - return "Lm" - case _ if var_name == f"mag_field/{mag_field}/B_local": - return "B_total" - case "psd/PSD": - return "PSD" - case _ if var_name == f"psd/{mag_field}/inv_mu": - return "InvMu" - case _ if var_name == f"psd/{mag_field}/inv_K": - return "InvK" - case "density/density_local": - return "density" - case _: - return None diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index d51d0f87..1fa95486 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -20,8 +20,6 @@ SatelliteLiteral as SatelliteLiteral, VariableLiteral as VariableLiteral, ) -from swvo.io.RBMDataSet.RBMDataSetManager import RBMDataSetManager as RBMDataSetManager +from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType from swvo.io.RBMDataSet.scripts.create_RBSP_line_data import create_RBSP_line_data as create_RBSP_line_data -from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet -from swvo.io.RBMDataSet.RBMNcDataSet import RBMNcDataSet as RBMNcDataSet diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 411718d9..dedd3eb7 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -200,8 +200,26 @@ def _bin_in_space( grid_P_1d = None grid_R_1d = grid_R[0, :, 0, 0] - psd_binned = np.full((psd_in.shape[0], 1, grid_R.shape[1], psd_in.shape[1], psd_in.shape[2]), 0.0) - number_of_observations = np.full((psd_in.shape[0], 1, grid_R.shape[1], psd_in.shape[1], psd_in.shape[2]), 0) + psd_binned = np.full( + ( + psd_in.shape[0], + 1, + grid_R.shape[1], + psd_in.shape[1], + psd_in.shape[2], + ), + 0.0, + ) + number_of_observations = np.full( + ( + psd_in.shape[0], + 1, + grid_R.shape[1], + psd_in.shape[1], + psd_in.shape[2], + ), + 0, + ) for it in range(psd_in.shape[0]): if np.all(np.isnan(psd_in[it, :, :])): @@ -308,7 +326,14 @@ def _parallel_func_VK( V_finite = np.isfinite(V_data[it, :, K_idx_left]) V_sorted = 1 if np.all(np.diff(V_data[it, V_finite, K_idx_left]) >= 0) else -1 - V_idx_left_left = np.searchsorted(V_sorted * V_data[it, :, K_idx_left], V_sorted * V_val, side="right") - 1 + V_idx_left_left = ( + np.searchsorted( + V_sorted * V_data[it, :, K_idx_left], + V_sorted * V_val, + side="right", + ) + - 1 + ) V_idx_left_right = V_idx_left_left + 1 if V_idx_left_left == -1 or V_idx_left_right >= V_data.shape[1]: @@ -499,7 +524,12 @@ def plot_debug_figures( # plot satellite trajectory on PxR grid # [x_sat, y_sat] = pol2cart(self.P, self.R) - ax0.scatter(data_set.P[sat_time_idx], R_or_Lstar_arr[sat_time_idx], c="k", marker="D") + ax0.scatter( + data_set.P[sat_time_idx], + R_or_Lstar_arr[sat_time_idx], + c="k", + marker="D", + ) ax0.set_ylim(1, 6.6) ax0.set_title("Orbit") ax0.set_theta_offset(np.pi) # ty:ignore[unresolved-attribute] diff --git a/swvo/io/RBMDataSet/identify_orbits.py b/swvo/io/RBMDataSet/identify_orbits.py index c9db2e52..5e0b05cf 100644 --- a/swvo/io/RBMDataSet/identify_orbits.py +++ b/swvo/io/RBMDataSet/identify_orbits.py @@ -15,8 +15,7 @@ from scipy.interpolate import make_splrep from scipy.signal import find_peaks -if typing.TYPE_CHECKING: - from swvo.io.RBMDataSet import RBMDataSet, RBMNcDataSet +from swvo.io.RBMDataSet import RBMDataSet class Trajectory(NamedTuple): @@ -57,7 +56,7 @@ def _identify_orbits( def identify_orbits( - self: RBMDataSet | RBMNcDataSet, + self: RBMDataSet, orbit_type: Literal["R", "L*"] = "R", minimal_distance: int = 60, *, diff --git a/swvo/io/RBMDataSet/linearize_trajectories.py b/swvo/io/RBMDataSet/linearize_trajectories.py index fd3368b7..e47f95aa 100644 --- a/swvo/io/RBMDataSet/linearize_trajectories.py +++ b/swvo/io/RBMDataSet/linearize_trajectories.py @@ -6,16 +6,15 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Literal +from typing import Literal import numpy as np import pandas as pd from numpy.typing import NDArray from scipy.interpolate import interp1d -if TYPE_CHECKING: - from swvo.io.RBMDataSet import RBMDataSet, RBMNcDataSet - from swvo.io.RBMDataSet.identify_orbits import Trajectory +from swvo.io.RBMDataSet import RBMDataSet +from swvo.io.RBMDataSet.identify_orbits import Trajectory def _linearize_trajectories( @@ -79,7 +78,7 @@ def _linearize_trajectories( def linearize_trajectories( - self: RBMDataSet | RBMNcDataSet, + self: RBMDataSet, trajectories: list[Trajectory], orbit_type: Literal["R", "L*"] = "R", ) -> tuple[NDArray[np.floating], list[datetime]]: diff --git a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py index 7fa9f49d..cd7d4273 100644 --- a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py +++ b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py @@ -15,7 +15,6 @@ InstrumentEnum, MfmEnum, RBMDataSet, - RBMDataSetManager, SatelliteEnum, SatelliteLike, TargetType, @@ -107,15 +106,15 @@ def create_RBSP_line_data( for i, instrument in enumerate(instruments): rbm_data.append( - RBMDataSetManager.load( + RBMDataSet( + satellite, # ty: ignore[invalid-argument-type] + instrument, + mfm, start_time, end_time, data_server_path, - satellite, - instrument, - mfm, verbose=verbose, - ) # ty:ignore[no-matching-overload] + ) ) # strip of time dimention diff --git a/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc b/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc new file mode 100644 index 00000000..69cbde5f Binary files /dev/null and b/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc differ diff --git a/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc.license b/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc.license new file mode 100644 index 00000000..54a4f5b0 --- /dev/null +++ b/tests/io/RBMDataSet/data/ARASE/arase/arase_XEP_20260301to20260331_T89.nc.license @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: 2026 GFZ Helmholtz Centre for Geosciences +SPDX-FileContributor: Sahil Jhawar + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 23253a66..fd81457f 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -402,6 +402,21 @@ def test_dict_mode_computed_invv_property(dict_dataset): np.testing.assert_allclose(dict_dataset.InvV, expected_invv) +def test_get_loaded_variables_includes_computed_variables(dict_dataset): + """Computed variables should be tracked once accessed.""" + dict_dataset.MLT = np.array([0.0, 6.0, 12.0]) + dict_dataset.InvMu = np.array([[0.1, 0.2]]) + dict_dataset.InvK = np.array([[1.0]]) + + _ = dict_dataset.P + _ = dict_dataset.InvV + + loaded_variables = dict_dataset.get_loaded_variables() + + assert "P" in loaded_variables + assert "InvV" in loaded_variables + + def test_dict_mode_getattr_errors(dict_dataset): """Test error handling for unset attributes in dict mode""" with pytest.raises(AttributeError, match="exists in `VariableLiteral` but has not been set"): @@ -835,3 +850,262 @@ def test_eq_different_types(): dataset2.time = [738000.0] assert dataset1 != dataset2 + + +@pytest.fixture +def mock_dataset_nc(mocker) -> RBMDataSet: + start_time = dt.datetime(2026, 3, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2026, 3, 31, tzinfo=timezone.utc) + + dataset = RBMDataSet( + start_time=start_time, + end_time=end_time, + folder_path=Path(__file__).parent / "./data/", + satellite=SatelliteEnum.ARASE, + instrument=InstrumentEnum.XEP, + mfm=MfmEnum.T89, + preferred_extension="nc", + verbose=True, + ) + + return dataset + + +def test_get_satellite_name_nc(mock_dataset_nc: RBMDataSet): + """Test get_satellite_name method.""" + assert mock_dataset_nc.get_satellite_name() == "arase" + + +def test_get_satellite_and_instrument_name_nc(mock_dataset_nc: RBMDataSet): + """Test get_satellite_and_instrument_name method.""" + assert mock_dataset_nc.get_satellite_and_instrument_name() == "arase_XEP" + + +def test_get_print_name_nc(mock_dataset_nc: RBMDataSet): + """Test get_print_name method.""" + assert mock_dataset_nc.get_print_name() == "arase XEP" + + +def test_getattr_with_valid_variable_nc(mock_dataset_nc: RBMDataSet): + """Test __getattr__ with a valid variable.""" + with mock.patch.object(mock_dataset_nc, "_load_variable") as _: + mock_dataset_nc.Flux = np.array([1.0, 2.0, 3.0]) + result = mock_dataset_nc.Flux + assert isinstance(result, np.ndarray) + assert (result == np.array([1.0, 2.0, 3.0])).all() + + +def test_getattr_with_invalid_variable_nc(mock_dataset_nc: RBMDataSet): + """Test __getattr__ with an invalid variable.""" + with pytest.raises(AttributeError): + _ = mock_dataset_nc.NonExistentAttribute + + +def test_getattr_with_similar_variable_nc(mock_dataset_nc: RBMDataSet): + """Test __getattr__ suggests similar variable name.""" + with pytest.raises(AttributeError) as e: + _ = mock_dataset_nc.Flx + + assert "Maybe you meant Flux?" in str(e.value) + + +def test_computed_invv_variable_nc(mock_dataset_nc: RBMDataSet): + """Test computed InvV variable.""" + + mock_dataset_nc.InvK = np.array([[1.0, 2.0]]) + mock_dataset_nc.InvMu = np.array([[0.1, 0.2], [0.3, 0.4]]) + + mock_dataset_nc._load_variable(VariableEnum.INV_V) + + expected = ( + mock_dataset_nc.InvMu + * (np.repeat(mock_dataset_nc.InvK[:, np.newaxis, :], mock_dataset_nc.InvMu.shape[1], axis=1) + 0.5) ** 2 + ) + np.testing.assert_array_equal(mock_dataset_nc.InvV, expected) + + +def test_computed_p_variable_nc(mock_dataset_nc: RBMDataSet): + """Test computed P variable.""" + + mock_dataset_nc.MLT = np.array([0.0, 6.0, 12.0, 18.0]) + + mock_dataset_nc._load_variable(VariableEnum.P) + + expected = ((mock_dataset_nc.MLT + 12) / 12 * np.pi) % (2 * np.pi) + np.testing.assert_array_equal(mock_dataset_nc.P, expected) + + +@pytest.mark.parametrize("satellite", list(SatelliteEnum)) +def test_all_satellites_work_nc(satellite, mock_module_string): + """Ensure all SatelliteEnum values initialize without error.""" + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset = RBMDataSet( + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), + satellite=satellite, + instrument=InstrumentEnum.HOPE, + mfm=MfmEnum.T89, + preferred_extension="nc", + ) + assert dataset._satellite == satellite + + +@pytest.mark.parametrize("instrument", list(InstrumentEnum)) +def test_all_instruments_work_nc(instrument, mock_module_string): + """Ensure all InstrumentEnum values initialize without error.""" + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset = RBMDataSet( + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), + satellite=SatelliteEnum.RBSPA, + instrument=instrument, + mfm=MfmEnum.T89, + preferred_extension="nc", + ) + assert dataset._instrument == instrument + + +def test_create_date_list_monthly_nc(mock_dataset_nc: RBMDataSet): + """Test monthly cadence date generation.""" + mock_dataset_nc.set_file_cadence(FileCadenceEnum.Monthly) + date_list = mock_dataset_nc._create_date_list() + assert date_list[0].month == 3 + assert all(date.tzinfo == timezone.utc for date in date_list) + + +def test_create_date_list_daily_nc(mock_dataset_nc: RBMDataSet): + """Test daily cadence date generation.""" + mock_dataset_nc.set_file_cadence(FileCadenceEnum.Daily) + date_list = mock_dataset_nc._create_date_list() + assert len(date_list) > 20 + assert all(date.tzinfo == timezone.utc for date in date_list) + + +def test_file_name_stem_generation_nc(mock_dataset_nc: RBMDataSet): + """Test that file name stem is generated correctly.""" + assert mock_dataset_nc._create_file_name_stem() == "arase_XEP_" + + +def test_file_path_stem_dataserver_nc(mock_dataset_nc: RBMDataSet): + """Test correct file path stem for DataServer folder type.""" + expected_path = Path(__file__).parent / "./data/ARASE/arase" + assert mock_dataset_nc._create_file_path_stem() == expected_path + + +def test_invalid_cadence_raises_nc(mock_dataset_nc: RBMDataSet): + """Invalid cadence should raise ValueError.""" + mock_dataset_nc._file_cadence = None + with pytest.raises(ValueError): + mock_dataset_nc._create_date_list() + + +def test_invalid_folder_type_raises_nc(mock_dataset_nc: RBMDataSet): + """Invalid folder type should raise ValueError.""" + mock_dataset_nc._folder_type = None + with pytest.raises(ValueError): + mock_dataset_nc._create_file_path_stem() + + +def test_get_var_method_nc(mock_dataset_nc: RBMDataSet): + """Test get_var returns correct variable.""" + mock_dataset_nc.Flux = np.array([4.0, 5.0]) + result = mock_dataset_nc.get_var(VariableEnum.FLUX) + assert isinstance(result, np.ndarray) + assert (result == np.array([4.0, 5.0])).all() + + +def test_load_variable_real_file_nc(): + start_time = dt.datetime(2025, 4, 1, tzinfo=dt.timezone.utc) + end_time = dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc) + + dataset = RBMDataSet( + start_time=start_time, + end_time=end_time, + folder_path=Path("path/to/real/files"), # this does not matter for the test + satellite=SatelliteEnum.GOESSecondary, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + preferred_extension="nc", + verbose=True, + ) + + dataset._load_variable(VariableEnum.ALPHA_LOCAL) + + assert hasattr(dataset, "alpha_local"), "Dataset should have 'alpha_local' attribute after loading." + assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array." + assert hasattr(dataset, "FEDU") + + +def test_all_variables_in_dir_nc(mock_dataset_nc: RBMDataSet): + vars = [ + "datetime", + "time", + "energy_channels", + "alpha_local", + "alpha_eq_model", + "alpha_eq_real", + "InvMu", + "InvMu_real", + "InvK", + "InvV", + "Lstar", + "Flux", + "PSD", + "MLT", + "B_SM", + "B_total", + "B_sat", + "xGEO", + "P", + "R0", + "density", + ] + + for var in vars: + assert var in mock_dataset_nc.__dir__() + + +def test_load_all_variables_nc(mock_dataset_nc: RBMDataSet): + """Test that all variables can be loaded without error.""" + for var in VariableEnum: + try: + mock_dataset_nc._load_variable(var) + except Exception as e: + pytest.fail(f"Loading variable {var.var_name} raised an exception: {e}") + + +def test_load_variable_netcdf_caches_file_reads(mock_dataset_nc: RBMDataSet): + """Repeated NetCDF loads should reuse the parsed monthly file content.""" + datasets = { + "time": np.array([dt.datetime(2026, 3, 15, tzinfo=timezone.utc).timestamp()], dtype=np.int64), + "flux/alpha_local": np.array([[0.1, 0.2, 0.3]]), + "flux/FEDU": np.array([[1.0, 2.0, 3.0]]), + } + + with mock.patch( + "swvo.io.RBMDataSet.RBMDataSet._read_all_datasets_netcdf", + return_value=datasets, + ) as mock_read: + mock_dataset_nc._load_variable(VariableEnum.ALPHA_LOCAL) + mock_dataset_nc._load_variable(VariableEnum.FLUX) + + assert mock_read.call_count == 1 + np.testing.assert_array_equal(mock_dataset_nc.alpha_local, datasets["flux/alpha_local"]) + np.testing.assert_array_equal(mock_dataset_nc.Flux, datasets["flux/FEDU"]) + + +def test_get_loaded_variables_includes_computed_variables_nc(mock_dataset_nc: RBMDataSet): + """Computed variables should be tracked once accessed.""" + _ = mock_dataset_nc.P + _ = mock_dataset_nc.InvV + + loaded_variables = mock_dataset_nc.get_loaded_variables() + + assert "P" in loaded_variables + assert "InvV" in loaded_variables diff --git a/tests/io/RBMDataSet/test_RBMDatasetManager.py b/tests/io/RBMDataSet/test_RBMDatasetManager.py deleted file mode 100644 index b567e494..00000000 --- a/tests/io/RBMDataSet/test_RBMDatasetManager.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -import datetime as dt -from pathlib import Path - -import pytest - -from swvo.io.RBMDataSet import InstrumentEnum, MfmEnum, RBMDataSetManager, SatelliteEnum - - -@pytest.fixture -def manager_args(): - return { - "start_time": dt.datetime(2025, 4, 1, tzinfo=dt.timezone.utc), - "end_time": dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc), - "folder_path": Path("/mock/path"), - "instrument": InstrumentEnum.MAGED, - "mfm": MfmEnum.T89, - } - - -def test_singleton_prevents_direct_init(): - with pytest.raises(RuntimeError): - _ = RBMDataSetManager() - - -def test_single_satellite_returns_dataset(manager_args): - dataset = RBMDataSetManager.load( - satellite=SatelliteEnum.GOESSecondary, - **manager_args, - ) - assert dataset.get_satellite_name() == "secondary" - - -def test_same_parameters_return_same_instance(manager_args): - ds1 = RBMDataSetManager.load(satellite=SatelliteEnum.GOESSecondary, **manager_args) - ds2 = RBMDataSetManager.load(satellite=SatelliteEnum.GOESSecondary, **manager_args) - assert ds1 is ds2 - - -def test_different_satellite_returns_different_instance(manager_args): - ds1 = RBMDataSetManager.load(satellite=SatelliteEnum.GOESPrimary, **manager_args) - ds2 = RBMDataSetManager.load(satellite=SatelliteEnum.GOESSecondary, **manager_args) - assert ds1 is not ds2 - - -def test_string_input_for_satellite(manager_args): - dataset = RBMDataSetManager.load(satellite="GOESSecondary", **manager_args) - assert dataset.get_satellite_name() == "secondary" - - -def test_multiple_satellites_returns_list(manager_args): - datasets = RBMDataSetManager.load( - satellite=[SatelliteEnum.GOESPrimary, SatelliteEnum.GOESSecondary], - **manager_args, - ) - assert isinstance(datasets, list) - assert len(datasets) == 2 - assert all(isinstance(ds, type(datasets[0])) for ds in datasets) diff --git a/tests/io/RBMDataSet/test_RBMNcDataset.py b/tests/io/RBMDataSet/test_RBMNcDataset.py deleted file mode 100644 index c01190f6..00000000 --- a/tests/io/RBMDataSet/test_RBMNcDataset.py +++ /dev/null @@ -1,291 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -import datetime as dt -from datetime import timezone -from pathlib import Path -from unittest import mock - -import numpy as np -import pytest - -from swvo.io.RBMDataSet import ( - FileCadenceEnum, - InstrumentEnum, - MfmEnum, - RBMNcDataSet, - SatelliteEnum, - VariableEnum, -) - - -@pytest.fixture -def mock_module_string(): - return "swvo.io.RBMDataSet.RBMDataSet.RBMDataSet" - - -@pytest.fixture -def mock_dataset(mocker) -> RBMNcDataSet: - start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) - end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) - - mocker.patch( - "swvo.io.RBMDataSet.RBMNcDataSet._read_all_datasets_netcdf", - return_value={ - "time": np.array([dt.datetime(2023, 1, 15).timestamp()]), - "datetime": np.array([dt.datetime(2023, 1, 15, tzinfo=timezone.utc)]), - "flux/energy": np.array([100, 200, 300]), - "flux/alpha_local": np.array([0.1, 0.2, 0.3]), - "flux/FEDU": np.array([[1.0, 2.0, 3.0]]), - }, - ) - - dataset = RBMNcDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), - satellite=SatelliteEnum.RBSPA, - instrument=InstrumentEnum.MAGEIS, - mfm=MfmEnum.T89, - verbose=False, - ) - - return dataset - - -def test_init_datetime_timezone(mock_module_string): - """Test timezone handling for input datetimes.""" - - start_time = dt.datetime(2023, 1, 1) - end_time = dt.datetime(2023, 1, 31) - - with ( - mock.patch(f"{mock_module_string}._create_date_list"), - mock.patch(f"{mock_module_string}._create_file_path_stem"), - mock.patch(f"{mock_module_string}._create_file_name_stem"), - ): - dataset = RBMNcDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), - satellite=SatelliteEnum.RBSPA, - instrument=InstrumentEnum.MAGEIS, - mfm=MfmEnum.T89, - ) - - assert dataset._start_time.tzinfo == timezone.utc - assert dataset._end_time.tzinfo == timezone.utc - - -def test_get_satellite_name(mock_dataset: RBMNcDataSet): - """Test get_satellite_name method.""" - assert mock_dataset.get_satellite_name() == "rbspa" - - -def test_get_satellite_and_instrument_name(mock_dataset: RBMNcDataSet): - """Test get_satellite_and_instrument_name method.""" - assert mock_dataset.get_satellite_and_instrument_name() == "rbspa_mageis" - - -def test_get_print_name(mock_dataset: RBMNcDataSet): - """Test get_print_name method.""" - assert mock_dataset.get_print_name() == "rbspa mageis" - - -def test_satellite_string_input(mock_module_string): - """Test that satellite can be provided as string.""" - with mock.patch(f"{mock_module_string}._create_date_list"): - with mock.patch(f"{mock_module_string}._create_file_path_stem"): - with mock.patch(f"{mock_module_string}._create_file_name_stem"): - dataset = RBMNcDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), - satellite="RBSPA", - instrument=InstrumentEnum.MAGEIS, - mfm=MfmEnum.T89, - ) - - assert dataset._satellite == SatelliteEnum.RBSPA - - -def test_getattr_with_valid_variable(mock_dataset: RBMNcDataSet): - """Test __getattr__ with a valid variable.""" - with mock.patch.object(mock_dataset, "_load_variable") as _: - mock_dataset.Flux = np.array([1.0, 2.0, 3.0]) - result = mock_dataset.Flux - assert isinstance(result, np.ndarray) - assert (result == np.array([1.0, 2.0, 3.0])).all() - - -def test_getattr_with_invalid_variable(mock_dataset: RBMNcDataSet): - """Test __getattr__ with an invalid variable.""" - with pytest.raises(AttributeError): - _ = mock_dataset.NonExistentAttribute - - -def test_getattr_with_similar_variable(mock_dataset: RBMNcDataSet): - """Test __getattr__ suggests similar variable name.""" - with pytest.raises(AttributeError) as e: - _ = mock_dataset.Flx - - assert "Maybe you meant Flux?" in str(e.value) - - -def test_computed_invv_variable(mock_dataset: RBMNcDataSet): - """Test computed InvV variable.""" - - mock_dataset.InvK = np.array([[1.0, 2.0]]) - mock_dataset.InvMu = np.array([[0.1, 0.2], [0.3, 0.4]]) - - mock_dataset._load_variable(VariableEnum.INV_V) - - expected = ( - mock_dataset.InvMu - * (np.repeat(mock_dataset.InvK[:, np.newaxis, :], mock_dataset.InvMu.shape[1], axis=1) + 0.5) ** 2 - ) - np.testing.assert_array_equal(mock_dataset.InvV, expected) - - -def test_computed_p_variable(mock_dataset: RBMNcDataSet): - """Test computed P variable.""" - - mock_dataset.MLT = np.array([0.0, 6.0, 12.0, 18.0]) - - mock_dataset._load_variable(VariableEnum.P) - - expected = ((mock_dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) - np.testing.assert_array_equal(mock_dataset.P, expected) - - -@pytest.mark.parametrize("satellite", list(SatelliteEnum)) -def test_all_satellites_work(satellite, mock_module_string): - """Ensure all SatelliteEnum values initialize without error.""" - with mock.patch(f"{mock_module_string}._create_date_list"): - with mock.patch(f"{mock_module_string}._create_file_path_stem"): - with mock.patch(f"{mock_module_string}._create_file_name_stem"): - dataset = RBMNcDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), - satellite=satellite, - instrument=InstrumentEnum.HOPE, - mfm=MfmEnum.T89, - ) - assert dataset._satellite == satellite - - -@pytest.mark.parametrize("instrument", list(InstrumentEnum)) -def test_all_instruments_work(instrument, mock_module_string): - """Ensure all InstrumentEnum values initialize without error.""" - with mock.patch(f"{mock_module_string}._create_date_list"): - with mock.patch(f"{mock_module_string}._create_file_path_stem"): - with mock.patch(f"{mock_module_string}._create_file_name_stem"): - dataset = RBMNcDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), - satellite=SatelliteEnum.RBSPA, - instrument=instrument, - mfm=MfmEnum.T89, - ) - assert dataset._instrument == instrument - - -def test_create_date_list_monthly(mock_dataset: RBMNcDataSet): - """Test monthly cadence date generation.""" - mock_dataset.set_file_cadence(FileCadenceEnum.Monthly) - date_list = mock_dataset._create_date_list() - assert date_list[0].month == 1 - assert all(date.tzinfo == timezone.utc for date in date_list) - - -def test_create_date_list_daily(mock_dataset: RBMNcDataSet): - """Test daily cadence date generation.""" - mock_dataset.set_file_cadence(FileCadenceEnum.Daily) - date_list = mock_dataset._create_date_list() - assert len(date_list) > 20 - assert all(date.tzinfo == timezone.utc for date in date_list) - - -def test_file_name_stem_generation(mock_dataset: RBMNcDataSet): - """Test that file name stem is generated correctly.""" - assert mock_dataset._create_file_name_stem() == "rbspa_mageis_" - - -def test_file_path_stem_dataserver(mock_dataset: RBMNcDataSet): - """Test correct file path stem for DataServer folder type.""" - expected_path = Path("/mock/path/RBSP/rbspa/") - assert mock_dataset._create_file_path_stem() == expected_path - - -def test_invalid_cadence_raises(mock_dataset: RBMNcDataSet): - """Invalid cadence should raise ValueError.""" - mock_dataset._file_cadence = None - with pytest.raises(ValueError): - mock_dataset._create_date_list() - - -def test_invalid_folder_type_raises(mock_dataset: RBMNcDataSet): - """Invalid folder type should raise ValueError.""" - mock_dataset._folder_type = None - with pytest.raises(ValueError): - mock_dataset._create_file_path_stem() - - -def test_get_var_method(mock_dataset: RBMNcDataSet): - """Test get_var returns correct variable.""" - mock_dataset.Flux = np.array([4.0, 5.0]) - result = mock_dataset.get_var(VariableEnum.FLUX) - assert isinstance(result, np.ndarray) - assert (result == np.array([4.0, 5.0])).all() - - -def test_load_variable_real_file(): - start_time = dt.datetime(2025, 4, 1, tzinfo=dt.timezone.utc) - end_time = dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc) - - dataset = RBMNcDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("path/to/real/files"), # this does not matter for the test - satellite=SatelliteEnum.GOESSecondary, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - verbose=True, - ) - - dataset._load_variable(VariableEnum.ALPHA_LOCAL) - - assert hasattr(dataset, "alpha_local"), "Dataset should have 'alpha_local' attribute after loading." - assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array." - assert hasattr(dataset, "FEDU") - -def test_all_variables_in_dir(mock_dataset: RBMNcDataSet): - vars = [ - "datetime", - "time", - "energy_channels", - "alpha_local", - "alpha_eq_model", - "alpha_eq_real", - "InvMu", - "InvMu_real", - "InvK", - "InvV", - "Lstar", - "Flux", - "PSD", - "MLT", - "B_SM", - "B_total", - "B_sat", - "xGEO", - "P", - "R0", - "density", - ] - - for var in vars: - assert var in mock_dataset.__dir__()