diff --git a/hashprep/__init__.py b/hashprep/__init__.py index 6534f0d..3d162f3 100644 --- a/hashprep/__init__.py +++ b/hashprep/__init__.py @@ -1,3 +1,5 @@ +from .config import HashPrepConfig as HashPrepConfig from .core.analyzer import DatasetAnalyzer as DatasetAnalyzer +from .utils.config_loader import load_config as load_config __version__ = "0.1.0b2" diff --git a/hashprep/checks/__init__.py b/hashprep/checks/__init__.py index 9e5daaa..672d274 100644 --- a/hashprep/checks/__init__.py +++ b/hashprep/checks/__init__.py @@ -29,7 +29,13 @@ def _check_dataset_drift(analyzer): """Wrapper for drift detection that uses analyzer's comparison_df.""" if hasattr(analyzer, "comparison_df") and analyzer.comparison_df is not None: - return check_drift(analyzer.df, analyzer.comparison_df) + drift_cfg = analyzer.config.drift + return check_drift( + analyzer.df, + analyzer.comparison_df, + threshold=drift_cfg.p_value, + config=drift_cfg, + ) return [] diff --git a/hashprep/checks/columns.py b/hashprep/checks/columns.py index 55881cc..48d92f3 100644 --- a/hashprep/checks/columns.py +++ b/hashprep/checks/columns.py @@ -1,8 +1,5 @@ -from ..config import DEFAULT_CONFIG from .core import Issue -_COL_THRESHOLDS = DEFAULT_CONFIG.columns - def _check_single_value_columns(analyzer): issues = [] @@ -28,18 +25,15 @@ def _check_single_value_columns(analyzer): return issues -def _check_high_cardinality( - analyzer, - threshold: int = _COL_THRESHOLDS.high_cardinality_count, - critical_threshold: float = _COL_THRESHOLDS.high_cardinality_ratio_critical, -): +def _check_high_cardinality(analyzer): + _cfg = analyzer.config.columns issues = [] categorical_cols = analyzer.df.select_dtypes(include="object").columns.tolist() for col in categorical_cols: unique_count = int(analyzer.df[col].nunique()) unique_ratio = float(unique_count / len(analyzer.df)) - if unique_count > threshold: - severity = "critical" if unique_ratio > critical_threshold else "warning" + if unique_count > _cfg.high_cardinality_count: + severity = "critical" if unique_ratio > _cfg.high_cardinality_ratio_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Drop column: Avoids overfitting from unique identifiers (Pros: Simplifies model; Cons: Loses potential info).\n- Engineer feature: Extract patterns (e.g., titles from names) (Pros: Retains useful info; Cons: Requires domain knowledge).\n- Use hashing: Reduce dimensionality (Pros: Scalable; Cons: May lose interpretability)." @@ -61,10 +55,11 @@ def _check_high_cardinality( def _check_duplicates(analyzer): issues = [] + _cfg = analyzer.config.columns duplicate_rows = int(analyzer.df.duplicated().sum()) if duplicate_rows > 0: duplicate_ratio = float(duplicate_rows / len(analyzer.df)) - severity = "critical" if duplicate_ratio > _COL_THRESHOLDS.duplicate_ratio_critical else "warning" + severity = "critical" if duplicate_ratio > _cfg.duplicate_ratio_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Drop duplicates: Ensures data integrity (Pros: Cleaner data; Cons: May lose valid repeats).\n- Verify duplicates: Check if intentional (e.g., time-series) (Pros: Validates data; Cons: Time-consuming)." diff --git a/hashprep/checks/correlations.py b/hashprep/checks/correlations.py index 0955ba3..665793f 100644 --- a/hashprep/checks/correlations.py +++ b/hashprep/checks/correlations.py @@ -4,16 +4,10 @@ import pandas as pd from scipy.stats import chi2_contingency, kendalltau, pearsonr, spearmanr -from ..config import DEFAULT_CONFIG from ..utils.type_inference import is_usable_for_corr from .core import Issue from .discretizer import DiscretizationType, Discretizer -_CORR = DEFAULT_CONFIG.correlations -CORR_THRESHOLDS = _CORR.as_nested_dict() -CAT_MAX_DISTINCT = _CORR.max_distinct_categories -LOW_CARD_NUM_THRESHOLD = _CORR.low_cardinality_numeric - def _cramers_v_corrected(table: pd.DataFrame) -> float: if table.empty or (table.shape[0] == 1 or table.shape[1] == 1): @@ -37,8 +31,9 @@ def calculate_correlations(analyzer, thresholds=None): Compute correlations using internal defaults: Spearman + Pearson for numerics, with Kendall added automatically for low-cardinality pairs. """ + _cfg = analyzer.config.correlations if thresholds is None: - thresholds = CORR_THRESHOLDS + thresholds = _cfg.as_nested_dict() inferred_types = analyzer.column_types # Use analyzer.column_types for inferred types dict issues = [] @@ -50,7 +45,7 @@ def calculate_correlations(analyzer, thresholds=None): col for col, typ in inferred_types.items() if typ == "Categorical" - and 1 < analyzer.df[col].nunique() <= CAT_MAX_DISTINCT + and 1 < analyzer.df[col].nunique() <= _cfg.max_distinct_categories and is_usable_for_corr(analyzer.df[col]) ] @@ -62,6 +57,7 @@ def calculate_correlations(analyzer, thresholds=None): def _check_numeric_correlation(analyzer, numeric_cols: list, thresholds: dict): + _cfg = analyzer.config.correlations issues = [] if len(numeric_cols) < 2: return issues @@ -85,7 +81,9 @@ def _check_numeric_correlation(analyzer, numeric_cols: list, thresholds: dict): # Kendall (only for low-cardinality numerics) kendall_corr, kendall_p = None, None - is_low_card = series1.nunique() <= LOW_CARD_NUM_THRESHOLD or series2.nunique() <= LOW_CARD_NUM_THRESHOLD + is_low_card = ( + series1.nunique() <= _cfg.low_cardinality_numeric or series2.nunique() <= _cfg.low_cardinality_numeric + ) if is_low_card: kendall_corr, kendall_p = kendalltau(series1, series2) kendall_corr = abs(kendall_corr) diff --git a/hashprep/checks/datetime_checks.py b/hashprep/checks/datetime_checks.py index 6c5abfd..dc45717 100644 --- a/hashprep/checks/datetime_checks.py +++ b/hashprep/checks/datetime_checks.py @@ -1,11 +1,8 @@ import numpy as np import pandas as pd -from ..config import DEFAULT_CONFIG from .core import Issue -_DT_CFG = DEFAULT_CONFIG.datetime - def _coerce_datetime(series: pd.Series) -> pd.Series: """Return a datetime Series regardless of whether the source is datetime64 or object.""" @@ -21,6 +18,7 @@ def _datetime_cols(analyzer) -> list[str]: def _check_datetime_future_dates(analyzer) -> list[Issue]: """Flag datetime columns that contain values in the future (likely data errors).""" + _cfg = analyzer.config.datetime issues = [] now = pd.Timestamp.now() @@ -34,7 +32,7 @@ def _check_datetime_future_dates(analyzer) -> list[Issue]: continue future_ratio = future_count / len(dt) - severity = "critical" if future_ratio > _DT_CFG.future_date_critical_ratio else "warning" + severity = "critical" if future_ratio > _cfg.future_date_critical_ratio else "warning" impact = "high" if severity == "critical" else "medium" issues.append( Issue( @@ -59,11 +57,12 @@ def _check_datetime_future_dates(analyzer) -> list[Issue]: def _check_datetime_gaps(analyzer) -> list[Issue]: """Detect anomalously large gaps in datetime columns (broken time series).""" + _cfg = analyzer.config.datetime issues = [] for col in _datetime_cols(analyzer): dt = _coerce_datetime(analyzer.df[col]).sort_values() - if len(dt) < _DT_CFG.min_rows_for_gap_check: + if len(dt) < _cfg.min_rows_for_gap_check: continue diffs = dt.diff().dropna() @@ -79,8 +78,8 @@ def _check_datetime_gaps(analyzer) -> list[Issue]: max_gap = float(diff_seconds.max()) ratio = max_gap / median_gap - if ratio >= _DT_CFG.gap_multiplier_warning: - severity = "critical" if ratio >= _DT_CFG.gap_multiplier_critical else "warning" + if ratio >= _cfg.gap_multiplier_warning: + severity = "critical" if ratio >= _cfg.gap_multiplier_critical else "warning" impact = "high" if severity == "critical" else "medium" # Locate the gap for a human-readable description @@ -113,11 +112,12 @@ def _check_datetime_gaps(analyzer) -> list[Issue]: def _check_datetime_monotonicity(analyzer) -> list[Issue]: """Warn when a datetime column that looks like a time-series index is non-monotonic.""" + _cfg = analyzer.config.datetime issues = [] for col in _datetime_cols(analyzer): dt = _coerce_datetime(analyzer.df[col]) - if len(dt) < _DT_CFG.min_rows_for_gap_check: + if len(dt) < _cfg.min_rows_for_gap_check: continue # Only flag if the column has mostly unique values (i.e., likely an index/timestamp) diff --git a/hashprep/checks/distribution.py b/hashprep/checks/distribution.py index 7446d5e..d2a0d65 100644 --- a/hashprep/checks/distribution.py +++ b/hashprep/checks/distribution.py @@ -1,21 +1,19 @@ from scipy.stats import kstest -from ..config import DEFAULT_CONFIG from .core import Issue -_DIST = DEFAULT_CONFIG.distribution - -def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_value) -> list[Issue]: +def _check_uniform_distribution(analyzer) -> list[Issue]: """ Detect uniformly distributed numeric columns using Kolmogorov-Smirnov test. Uniform distributions often indicate synthetic IDs or sequential data. """ + _cfg = analyzer.config.distribution issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col].dropna() - if len(series) < _DIST.uniform_min_samples: + if len(series) < _cfg.uniform_min_samples: continue min_val, max_val = series.min(), series.max() @@ -26,7 +24,7 @@ def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_v _, p_val = kstest(normalized, "uniform") is_monotonic = series.is_monotonic_increasing or series.is_monotonic_decreasing - if p_val > p_threshold or is_monotonic: + if p_val > _cfg.uniform_p_value or is_monotonic: monotonic_note = " and monotonic" if is_monotonic else "" issues.append( Issue( @@ -47,22 +45,23 @@ def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_v return issues -def _check_unique_values(analyzer, threshold: float = _DIST.unique_value_ratio) -> list[Issue]: +def _check_unique_values(analyzer) -> list[Issue]: """ Detect columns where nearly all values are unique. High uniqueness often indicates identifiers, names, or free-text fields. """ + _cfg = analyzer.config.distribution issues = [] for col in analyzer.df.columns: series = analyzer.df[col].dropna() - if len(series) < _DIST.unique_min_samples: + if len(series) < _cfg.unique_min_samples: continue unique_count = series.nunique() unique_ratio = unique_count / len(series) - if unique_ratio >= threshold: + if unique_ratio >= _cfg.unique_value_ratio: issues.append( Issue( category="unique_values", diff --git a/hashprep/checks/drift.py b/hashprep/checks/drift.py index 332a4bc..db84b86 100644 --- a/hashprep/checks/drift.py +++ b/hashprep/checks/drift.py @@ -9,14 +9,13 @@ _log = get_logger("checks.drift") _DRIFT = DEFAULT_CONFIG.drift -CRITICAL_P_VALUE = _DRIFT.critical_p_value -MAX_CATEGORIES_FOR_CHI2 = _DRIFT.max_categories_for_chi2 def check_drift( df_train: pd.DataFrame, df_test: pd.DataFrame, threshold: float = _DRIFT.p_value, + config=None, ) -> list[Issue]: """ Check for distribution shift between two datasets. @@ -25,10 +24,11 @@ def check_drift( if not isinstance(df_train, pd.DataFrame) or not isinstance(df_test, pd.DataFrame): raise TypeError("Both df_train and df_test must be pandas DataFrames") + drift_cfg = config if config is not None else _DRIFT issues = [] - issues.extend(_check_numeric_drift(df_train, df_test, threshold)) - issues.extend(_check_categorical_drift(df_train, df_test, threshold)) + issues.extend(_check_numeric_drift(df_train, df_test, threshold, drift_cfg)) + issues.extend(_check_categorical_drift(df_train, df_test, threshold, drift_cfg)) return issues @@ -37,6 +37,7 @@ def _check_numeric_drift( df_train: pd.DataFrame, df_test: pd.DataFrame, threshold: float, + drift_cfg, ) -> list[Issue]: """Check numeric columns for distribution drift using KS-test.""" issues = [] @@ -55,7 +56,7 @@ def _check_numeric_drift( stat, p_val = ks_2samp(train_vals, test_vals) if p_val < threshold: - severity = "critical" if p_val < CRITICAL_P_VALUE else "warning" + severity = "critical" if p_val < drift_cfg.critical_p_value else "warning" issues.append( Issue( category="dataset_drift", @@ -74,6 +75,7 @@ def _check_categorical_drift( df_train: pd.DataFrame, df_test: pd.DataFrame, threshold: float, + drift_cfg, ) -> list[Issue]: """Check categorical columns for distribution drift using Chi-square test.""" issues = [] @@ -88,20 +90,20 @@ def _check_categorical_drift( new_categories = set(test_counts.index) - set(train_counts.index) if new_categories: - sample_new = list(new_categories)[: _DRIFT.max_new_category_samples] + sample_new = list(new_categories)[: drift_cfg.max_new_category_samples] issues.append( Issue( category="dataset_drift", severity="warning", column=col, - description=f"New categories in test set for '{col}': {sample_new}{'...' if len(new_categories) > _DRIFT.max_new_category_samples else ''}", + description=f"New categories in test set for '{col}': {sample_new}{'...' if len(new_categories) > drift_cfg.max_new_category_samples else ''}", impact_score="medium", quick_fix="Handle unseen categories in preprocessing pipeline (e.g., OrdinalEncoder with unknown_value).", ) ) all_cats = list(set(train_counts.index) | set(test_counts.index)) - if len(all_cats) > MAX_CATEGORIES_FOR_CHI2: + if len(all_cats) > drift_cfg.max_categories_for_chi2: continue train_total = train_counts.sum() @@ -127,7 +129,7 @@ def _check_categorical_drift( chi2_stat, p_val = chisquare(observed_arr, f_exp=expected_arr) if p_val < threshold: - severity = "critical" if p_val < CRITICAL_P_VALUE else "warning" + severity = "critical" if p_val < drift_cfg.critical_p_value else "warning" issues.append( Issue( category="dataset_drift", diff --git a/hashprep/checks/imbalance.py b/hashprep/checks/imbalance.py index 19a2cd3..94d9e8c 100644 --- a/hashprep/checks/imbalance.py +++ b/hashprep/checks/imbalance.py @@ -1,8 +1,8 @@ -from ..config import DEFAULT_CONFIG from .core import Issue -def _check_class_imbalance(analyzer, threshold: float = DEFAULT_CONFIG.imbalance.majority_class_ratio): +def _check_class_imbalance(analyzer): + threshold = analyzer.config.imbalance.majority_class_ratio issues = [] if analyzer.target_col and analyzer.target_col in analyzer.df.columns: counts = analyzer.df[analyzer.target_col].value_counts(normalize=True) diff --git a/hashprep/checks/leakage.py b/hashprep/checks/leakage.py index c98e650..a0764cf 100644 --- a/hashprep/checks/leakage.py +++ b/hashprep/checks/leakage.py @@ -2,11 +2,9 @@ import pandas as pd from scipy.stats import chi2_contingency, f_oneway -from ..config import DEFAULT_CONFIG from ..utils.logging import get_logger from .core import Issue -_LEAK = DEFAULT_CONFIG.leakage _log = get_logger("checks.leakage") _LEAKAGE_CRITICAL_FIX = ( @@ -42,6 +40,7 @@ def _check_data_leakage(analyzer): def _check_target_leakage_patterns(analyzer): + _leak = analyzer.config.leakage issues = [] if analyzer.target_col and analyzer.target_col in analyzer.df.columns: target = analyzer.df[analyzer.target_col] @@ -55,9 +54,9 @@ def _check_target_leakage_patterns(analyzer): for col, corr in corrs.items(): severity = ( "critical" - if corr > _LEAK.numeric_critical + if corr > _leak.numeric_critical else "warning" - if corr > _LEAK.numeric_warning + if corr > _leak.numeric_warning else None ) if severity: @@ -86,9 +85,9 @@ def _check_target_leakage_patterns(analyzer): cramers_v = np.sqrt(phi2 / min(k - 1, r - 1)) severity = ( "critical" - if cramers_v > _LEAK.categorical_critical + if cramers_v > _leak.categorical_critical else "warning" - if cramers_v > _LEAK.categorical_warning + if cramers_v > _leak.categorical_warning else None ) if severity: @@ -122,9 +121,9 @@ def _check_target_leakage_patterns(analyzer): f_stat, p_val = f_oneway(*groups) severity = ( "critical" - if f_stat > _LEAK.f_stat_critical and p_val < _LEAK.f_stat_p_value + if f_stat > _leak.f_stat_critical and p_val < _leak.f_stat_p_value else "warning" - if f_stat > _LEAK.f_stat_warning and p_val < _LEAK.f_stat_p_value + if f_stat > _leak.f_stat_warning and p_val < _leak.f_stat_p_value else None ) if severity: diff --git a/hashprep/checks/missing_values.py b/hashprep/checks/missing_values.py index 58e758c..63c513a 100644 --- a/hashprep/checks/missing_values.py +++ b/hashprep/checks/missing_values.py @@ -14,14 +14,13 @@ _THRESHOLDS = DEFAULT_CONFIG.missing_values -def _check_high_missing_values( - analyzer, threshold: float = _THRESHOLDS.warning, critical_threshold: float = _THRESHOLDS.critical -): +def _check_high_missing_values(analyzer): + _cfg = analyzer.config.missing_values issues = [] for col in analyzer.df.columns: missing_pct = float(analyzer.df[col].isna().mean()) - if missing_pct > threshold: - severity = "critical" if missing_pct > critical_threshold else "warning" + if missing_pct > _cfg.warning: + severity = "critical" if missing_pct > _cfg.critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Drop column: Reduces bias from missing data (Pros: Simplifies model; Cons: Loses potential info).\n- Impute values: Use domain-informed methods (e.g., median, mode, or predictive model) (Pros: Retains feature; Cons: May introduce bias).\n- Create missingness indicator: Flag missing values as a new feature (Pros: Captures missingness pattern; Cons: Adds complexity)." @@ -58,15 +57,12 @@ def _check_empty_columns(analyzer): return issues -def _check_dataset_missingness( - analyzer, - threshold: float = _THRESHOLDS.dataset_warning_pct, - critical_threshold: float = _THRESHOLDS.dataset_critical_pct, -): +def _check_dataset_missingness(analyzer): + _cfg = analyzer.config.missing_values issues = [] missing_pct = float((analyzer.df.isnull().sum().sum() / (analyzer.df.shape[0] * analyzer.df.shape[1])) * 100) - if missing_pct > threshold: - severity = "critical" if missing_pct > critical_threshold else "warning" + if missing_pct > _cfg.dataset_warning_pct: + severity = "critical" if missing_pct > _cfg.dataset_critical_pct else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Drop sparse columns: Reduces bias from missingness (Pros: Simplifies model; Cons: Loses info).\n- Impute globally: Use advanced methods (e.g., predictive models) (Pros: Retains features; Cons: Risk of bias).\n- Investigate source: Check data collection issues (Pros: Improves quality; Cons: Time-consuming)." @@ -86,16 +82,13 @@ def _check_dataset_missingness( return issues -def _check_missing_patterns( - analyzer, - threshold: float = _THRESHOLDS.pattern_p_value, - critical_p_threshold: float = _THRESHOLDS.pattern_critical_p_value, -): +def _check_missing_patterns(analyzer): + _cfg = analyzer.config.missing_values + threshold = _cfg.pattern_p_value + critical_p_threshold = _cfg.pattern_critical_p_value issues = [] missing_cols = [ - col - for col in analyzer.df.columns - if int(analyzer.df[col].isna().sum()) >= _THRESHOLDS.pattern_min_missing_count + col for col in analyzer.df.columns if int(analyzer.df[col].isna().sum()) >= _cfg.pattern_min_missing_count ] # grouping logic @@ -108,7 +101,7 @@ def _check_missing_patterns( continue try: value_counts = analyzer.df[other_col].value_counts() - rare_cats = value_counts[value_counts < _THRESHOLDS.pattern_rare_category_count].index + rare_cats = value_counts[value_counts < _cfg.pattern_rare_category_count].index temp_col = analyzer.df[other_col].copy() if not rare_cats.empty: temp_col = temp_col.where(~temp_col.isin(rare_cats), "Other") @@ -131,7 +124,7 @@ def cramers_v(table): return np.sqrt(phi2corr / rkcorr) cramers = cramers_v(table) - if p_val < threshold and cramers > _THRESHOLDS.pattern_cramers_v_min: + if p_val < threshold and cramers > _cfg.pattern_cramers_v_min: cat_patterns[col].append((other_col, p_val, cramers)) except (ValueError, LinAlgError) as e: _log.debug("Chi-square test failed for '%s' vs '%s': %s", col, other_col, e) @@ -143,10 +136,7 @@ def cramers_v(table): try: missing = analyzer.df[analyzer.df[col].isna()][other_col].dropna() non_missing = analyzer.df[analyzer.df[col].notna()][other_col].dropna() - if ( - len(missing) < _THRESHOLDS.pattern_min_group_size - or len(non_missing) < _THRESHOLDS.pattern_min_group_size - ): + if len(missing) < _cfg.pattern_min_group_size or len(non_missing) < _cfg.pattern_min_group_size: continue # Replaced f_oneway with mannwhitneyu @@ -156,7 +146,7 @@ def cramers_v(table): pooled_std = np.sqrt((np.std(missing) ** 2 + np.std(non_missing) ** 2) / 2) cohens_d = abs(np.mean(missing) - np.mean(non_missing)) / pooled_std if pooled_std > 0 else 0 - if p_val < threshold and cohens_d > _THRESHOLDS.pattern_cohens_d_min: + if p_val < threshold and cohens_d > _cfg.pattern_cohens_d_min: num_patterns[col].append((other_col, p_val, cohens_d)) except (ValueError, RuntimeWarning) as e: _log.debug("Mann-Whitney U test failed for '%s' vs '%s': %s", col, other_col, e) @@ -173,7 +163,7 @@ def cramers_v(table): if all_patterns: # Sort by effect size (descending) and take top 3 all_patterns.sort(key=lambda x: x[2], reverse=True) # x[2] is effect size - top_corrs = [pat[0] for pat in all_patterns[: _THRESHOLDS.pattern_top_correlations]] + top_corrs = [pat[0] for pat in all_patterns[: _cfg.pattern_top_correlations]] total_count = len(all_patterns) desc = f"Missingness in '{col}' correlates with {total_count} columns ({', '.join(top_corrs)})" @@ -183,9 +173,7 @@ def cramers_v(table): is_target_correlated = any(pat[0] == analyzer.target_col for pat in all_patterns) severity = ( "critical" - if p_val < critical_p_threshold - and is_target_correlated - and max_effect > _THRESHOLDS.pattern_effect_critical + if p_val < critical_p_threshold and is_target_correlated and max_effect > _cfg.pattern_effect_critical else "warning" ) impact = "high" if severity == "critical" else "medium" diff --git a/hashprep/checks/mutual_info.py b/hashprep/checks/mutual_info.py index 625b69c..f34e45a 100644 --- a/hashprep/checks/mutual_info.py +++ b/hashprep/checks/mutual_info.py @@ -4,12 +4,9 @@ and is likely useless (or worse — noise) for a predictive model. """ -from ..config import DEFAULT_CONFIG from ..summaries.mutual_info import summarize_mutual_information from .core import Issue -_MI = DEFAULT_CONFIG.mutual_info - def _check_low_mutual_information(analyzer) -> list[Issue]: """ @@ -23,12 +20,13 @@ def _check_low_mutual_information(analyzer) -> list[Issue]: if not mi_result or not mi_result.get("scores"): return [] + _cfg = analyzer.config.mutual_info issues = [] scores = mi_result["scores"] task = mi_result["task"] for col, score in scores.items(): - if score < _MI.low_mi_warning: + if score < _cfg.low_mi_warning: issues.append( Issue( category="low_mutual_information", diff --git a/hashprep/checks/outliers.py b/hashprep/checks/outliers.py index 1cd3217..521a5f0 100644 --- a/hashprep/checks/outliers.py +++ b/hashprep/checks/outliers.py @@ -7,17 +7,18 @@ _THRESHOLDS = DEFAULT_CONFIG.outliers -def _check_outliers(analyzer, z_threshold: float = _THRESHOLDS.z_score): +def _check_outliers(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col].dropna() if len(series) == 0: continue z_scores = (series - series.mean()) / series.std(ddof=0) - outlier_count = int((abs(z_scores) > z_threshold).sum()) + outlier_count = int((abs(z_scores) > _cfg.z_score).sum()) if outlier_count > 0: outlier_ratio = float(outlier_count / len(series)) - severity = "critical" if outlier_ratio > _THRESHOLDS.outlier_ratio_critical else "warning" + severity = "critical" if outlier_ratio > _cfg.outlier_ratio_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Remove outliers: Improves model stability (Pros: Reduces noise; Cons: Loses data).\n- Winsorize: Cap extreme values (Pros: Retains data; Cons: Alters distribution).\n- Transform: Apply log/sqrt to reduce impact (Pros: Preserves info; Cons: Changes interpretation)." @@ -37,19 +38,16 @@ def _check_outliers(analyzer, z_threshold: float = _THRESHOLDS.z_score): return issues -def _check_high_zero_counts( - analyzer, - threshold: float = _THRESHOLDS.zero_count_warning, - critical_threshold: float = _THRESHOLDS.zero_count_critical, -): +def _check_high_zero_counts(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col].dropna() if len(series) == 0: continue zero_pct = float((series == 0).mean()) - if zero_pct > threshold: - severity = "critical" if zero_pct > critical_threshold else "warning" + if zero_pct > _cfg.zero_count_warning: + severity = "critical" if zero_pct > _cfg.zero_count_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Drop column: If zeros are not meaningful (Pros: Simplifies model; Cons: Loses info).\n- Transform: Use binary indicator or log transform (Pros: Retains info; Cons: Changes interpretation).\n- Verify zeros: Check if valid or errors (Pros: Ensures accuracy; Cons: Time-consuming)." @@ -69,18 +67,17 @@ def _check_high_zero_counts( return issues -def _check_extreme_text_lengths( - analyzer, max_threshold: int = _THRESHOLDS.text_length_max, min_threshold: int = _THRESHOLDS.text_length_min -): +def _check_extreme_text_lengths(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="object").columns: series = analyzer.df[col].dropna().astype(str) if series.empty: continue lengths = series.str.len() - if lengths.max() > max_threshold or lengths.min() < min_threshold: - extreme_ratio = float(((lengths > max_threshold) | (lengths < min_threshold)).mean()) - severity = "critical" if extreme_ratio > _THRESHOLDS.extreme_ratio_critical else "warning" + if lengths.max() > _cfg.text_length_max or lengths.min() < _cfg.text_length_min: + extreme_ratio = float(((lengths > _cfg.text_length_max) | (lengths < _cfg.text_length_min)).mean()) + severity = "critical" if extreme_ratio > _cfg.extreme_ratio_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Truncate values: Cap extreme lengths (Pros: Stabilizes model; Cons: Loses info).\n- Filter outliers: Remove extreme entries (Pros: Reduces noise; Cons: Loses data).\n- Transform: Normalize lengths (e.g., log) (Pros: Retains info; Cons: Changes interpretation)." @@ -100,21 +97,18 @@ def _check_extreme_text_lengths( return issues -def _check_skewness( - analyzer, - skew_threshold: float = _THRESHOLDS.skewness_warning, - critical_skew_threshold: float = _THRESHOLDS.skewness_critical, -): +def _check_skewness(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col].dropna() - if len(series) < _THRESHOLDS.min_sample_size: + if len(series) < _cfg.min_sample_size: continue skewness = float(series.skew()) abs_skew = abs(skewness) - if abs_skew > skew_threshold: - severity = "critical" if abs_skew > critical_skew_threshold else "warning" + if abs_skew > _cfg.skewness_warning: + severity = "critical" if abs_skew > _cfg.skewness_critical else "warning" impact = "high" if severity == "critical" else "medium" quick_fix = ( "Options: \n- Log transformation: Handles right skew (Pros: Normalizes; Cons: Only for positive).\n- Box-Cox/Yeo-Johnson: General power transforms (Pros: Robust; Cons: More complex).\n- Retain: Some models (trees) handle skewness well." @@ -134,14 +128,15 @@ def _check_skewness( return issues -def _check_datetime_skew(analyzer, threshold: float = _THRESHOLDS.datetime_skew): +def _check_datetime_skew(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="datetime64").columns: series = pd.to_datetime(analyzer.df[col], errors="coerce").dropna() if series.empty: continue year_counts = series.dt.year.value_counts(normalize=True) - if year_counts.max() > threshold: + if year_counts.max() > _cfg.datetime_skew: issues.append( Issue( category="datetime_skew", @@ -155,15 +150,15 @@ def _check_datetime_skew(analyzer, threshold: float = _THRESHOLDS.datetime_skew) return issues -def _check_infinite_values(analyzer, threshold: float = _THRESHOLDS.infinite_ratio_critical): - """Detect columns with infinite values.""" +def _check_infinite_values(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col] inf_count = int(np.isinf(series).sum()) if inf_count > 0: inf_ratio = inf_count / len(series) - severity = "critical" if inf_ratio > threshold else "warning" + severity = "critical" if inf_ratio > _cfg.infinite_ratio_critical else "warning" impact = "high" if severity == "critical" else "medium" issues.append( Issue( @@ -183,17 +178,17 @@ def _check_infinite_values(analyzer, threshold: float = _THRESHOLDS.infinite_rat return issues -def _check_constant_length(analyzer, threshold: float = _THRESHOLDS.constant_length_ratio): - """Detect string columns where all values have the same length (e.g., IDs, codes).""" +def _check_constant_length(analyzer): + _cfg = analyzer.config.outliers issues = [] for col in analyzer.df.select_dtypes(include="object").columns: series = analyzer.df[col].dropna().astype(str) - if len(series) < _THRESHOLDS.min_sample_size: + if len(series) < _cfg.min_sample_size: continue lengths = series.str.len() most_common_length_ratio = lengths.value_counts(normalize=True).iloc[0] if len(lengths) > 0 else 0 - if most_common_length_ratio >= threshold: + if most_common_length_ratio >= _cfg.constant_length_ratio: most_common_length = int(lengths.mode().iloc[0]) issues.append( Issue( diff --git a/hashprep/checks/statistical_tests.py b/hashprep/checks/statistical_tests.py index 04c42e4..059cdd6 100644 --- a/hashprep/checks/statistical_tests.py +++ b/hashprep/checks/statistical_tests.py @@ -6,19 +6,16 @@ import numpy as np from scipy.stats import levene, normaltest, shapiro -from ..config import DEFAULT_CONFIG from .core import Issue -_ST = DEFAULT_CONFIG.statistical_tests - -def _run_normality_test(series) -> tuple[str, float, float]: +def _run_normality_test(series, shapiro_max_n: int) -> tuple[str, float, float]: """ Return (test_name, statistic, p_value) for the most appropriate normality test. Uses Shapiro-Wilk for n <= shapiro_max_n, D'Agostino-Pearson otherwise. """ n = len(series) - if n <= _ST.shapiro_max_n: + if n <= shapiro_max_n: stat, p = shapiro(series) return "shapiro_wilk", float(stat), float(p) else: @@ -32,19 +29,20 @@ def _check_normality(analyzer) -> list[Issue]: Uses Shapiro-Wilk for n <= 5000, D'Agostino-Pearson for larger samples. Non-normality matters for linear models, t-tests, and certain imputation strategies. """ + _cfg = analyzer.config.statistical_tests issues = [] for col in analyzer.df.select_dtypes(include="number").columns: series = analyzer.df[col].dropna() n = len(series) - if n < _ST.normality_min_n: + if n < _cfg.normality_min_n: continue if series.nunique() <= 1: continue - test_name, stat, p_val = _run_normality_test(series) + test_name, stat, p_val = _run_normality_test(series, _cfg.shapiro_max_n) - if p_val < _ST.normality_p_value: + if p_val < _cfg.normality_p_value: # Severity: very small p → critical (strong evidence), otherwise warning severity = "critical" if p_val < 0.001 else "warning" impact = "high" if severity == "critical" else "medium" @@ -80,6 +78,7 @@ def _check_variance_homogeneity(analyzer) -> list[Issue]: Only runs when a target column is set and has at least 2 groups with sufficient data. """ + _cfg = analyzer.config.statistical_tests issues = [] if analyzer.target_col is None: @@ -99,7 +98,7 @@ def _check_variance_homogeneity(analyzer) -> list[Issue]: for label in groups_labels: mask = analyzer.df[analyzer.target_col] == label grp = series[mask].dropna().values - if len(grp) >= _ST.levene_min_group_size: + if len(grp) >= _cfg.levene_min_group_size: groups.append(grp) if len(groups) < 2: @@ -110,7 +109,7 @@ def _check_variance_homogeneity(analyzer) -> list[Issue]: except ValueError: continue - if p_val < _ST.levene_p_value: + if p_val < _cfg.levene_p_value: # Compute per-group stds to add colour to the description stds = [float(np.std(g, ddof=1)) for g in groups] std_ratio = max(stds) / min(stds) if min(stds) > 0 else float("inf") diff --git a/hashprep/config.py b/hashprep/config.py index 805eb3b..3bd4c84 100644 --- a/hashprep/config.py +++ b/hashprep/config.py @@ -7,6 +7,7 @@ """ from dataclasses import dataclass, field +from dataclasses import fields as _fields @dataclass(frozen=True) @@ -230,3 +231,27 @@ class HashPrepConfig: # Global default config instance DEFAULT_CONFIG = HashPrepConfig() + + +def config_from_dict(d: dict) -> "HashPrepConfig": + """Build a HashPrepConfig from a (possibly partial) nested dict. + + Unknown keys are silently ignored; missing keys fall back to defaults. + """ + default = HashPrepConfig() + + def _merge(cls, default_obj, overrides: dict): + kwargs = {} + for f in _fields(cls): + if f.name not in overrides: + kwargs[f.name] = getattr(default_obj, f.name) + else: + val = overrides[f.name] + field_default = getattr(default_obj, f.name) + if hasattr(field_default, "__dataclass_fields__") and isinstance(val, dict): + kwargs[f.name] = _merge(type(field_default), field_default, val) + else: + kwargs[f.name] = val + return cls(**kwargs) + + return _merge(HashPrepConfig, default, d) diff --git a/hashprep/core/analyzer.py b/hashprep/core/analyzer.py index a7fe3a2..0b704bb 100644 --- a/hashprep/core/analyzer.py +++ b/hashprep/core/analyzer.py @@ -6,6 +6,7 @@ from scipy.stats import ConstantInputWarning from ..checks import run_checks +from ..config import DEFAULT_CONFIG, HashPrepConfig from ..summaries import ( add_reproduction_info, get_dataset_preview, @@ -74,6 +75,7 @@ def __init__( comparison_df: pd.DataFrame | None = None, sampling_config: SamplingConfig | None = None, auto_sample: bool = True, + config: HashPrepConfig | None = None, ): if not isinstance(df, pd.DataFrame): raise TypeError(f"Expected pandas DataFrame, got {type(df).__name__}") @@ -84,6 +86,7 @@ def __init__( if comparison_df is not None and not isinstance(comparison_df, pd.DataFrame): raise TypeError(f"comparison_df must be a pandas DataFrame, got {type(comparison_df).__name__}") + self.config = config if config is not None else DEFAULT_CONFIG self.comparison_df = comparison_df self.target_col = target_col self.selected_checks = selected_checks diff --git a/hashprep/interfaces/cli/main.py b/hashprep/interfaces/cli/main.py index d5e216f..7959a15 100644 --- a/hashprep/interfaces/cli/main.py +++ b/hashprep/interfaces/cli/main.py @@ -13,6 +13,7 @@ from hashprep.preparers.pipeline_builder import PipelineBuilder from hashprep.preparers.suggestions import SuggestionProvider from hashprep.reports import generate_report +from hashprep.utils.config_loader import load_config from hashprep.utils.sampling import SamplingConfig @@ -70,7 +71,14 @@ def version(): help="Max rows for sampling (default: 100000)", ) @click.option("--no-sample", is_flag=True, help="Disable automatic sampling") -def scan(file_path, critical_only, quiet, json_out, target, checks, comparison, sample_size, no_sample): +@click.option( + "--config", + "config_path", + type=click.Path(exists=True), + default=None, + help="Path to config file (.yaml, .toml, .json)", +) +def scan(file_path, critical_only, quiet, json_out, target, checks, comparison, sample_size, no_sample, config_path): df = pd.read_csv(file_path) comparison_df = pd.read_csv(comparison) if comparison else None @@ -90,6 +98,7 @@ def scan(file_path, critical_only, quiet, json_out, target, checks, comparison, if not no_sample and sample_size: sampling_config = SamplingConfig(max_rows=sample_size) + config = load_config(config_path) if config_path else None analyzer = DatasetAnalyzer( df, target_col=target, @@ -97,6 +106,7 @@ def scan(file_path, critical_only, quiet, json_out, target, checks, comparison, comparison_df=comparison_df, sampling_config=sampling_config, auto_sample=not no_sample, + config=config, ) summary = analyzer.analyze() @@ -165,7 +175,14 @@ def scan(file_path, critical_only, quiet, json_out, target, checks, comparison, help="Max rows for sampling (default: 100000)", ) @click.option("--no-sample", is_flag=True, help="Disable automatic sampling") -def details(file_path, target, checks, comparison, sample_size, no_sample): +@click.option( + "--config", + "config_path", + type=click.Path(exists=True), + default=None, + help="Path to config file (.yaml, .toml, .json)", +) +def details(file_path, target, checks, comparison, sample_size, no_sample, config_path): df = pd.read_csv(file_path) comparison_df = pd.read_csv(comparison) if comparison else None @@ -185,6 +202,7 @@ def details(file_path, target, checks, comparison, sample_size, no_sample): if not no_sample and sample_size: sampling_config = SamplingConfig(max_rows=sample_size) + config = load_config(config_path) if config_path else None analyzer = DatasetAnalyzer( df, target_col=target, @@ -192,6 +210,7 @@ def details(file_path, target, checks, comparison, sample_size, no_sample): comparison_df=comparison_df, sampling_config=sampling_config, auto_sample=not no_sample, + config=config, ) summary = analyzer.analyze() @@ -276,6 +295,13 @@ def details(file_path, target, checks, comparison, sample_size, no_sample): help="Max rows for sampling (default: 100000)", ) @click.option("--no-sample", is_flag=True, help="Disable automatic sampling") +@click.option( + "--config", + "config_path", + type=click.Path(exists=True), + default=None, + help="Path to config file (.yaml, .toml, .json)", +) def report( file_path, with_code, @@ -288,6 +314,7 @@ def report( comparison, sample_size, no_sample, + config_path, ): df = pd.read_csv(file_path) comparison_df = pd.read_csv(comparison) if comparison else None @@ -308,6 +335,7 @@ def report( if not no_sample and sample_size: sampling_config = SamplingConfig(max_rows=sample_size) + config = load_config(config_path) if config_path else None analyzer = DatasetAnalyzer( df, target_col=target, @@ -316,6 +344,7 @@ def report( comparison_df=comparison_df, sampling_config=sampling_config, auto_sample=not no_sample, + config=config, ) summary = analyzer.analyze() diff --git a/hashprep/utils/config_loader.py b/hashprep/utils/config_loader.py new file mode 100644 index 0000000..53b56ac --- /dev/null +++ b/hashprep/utils/config_loader.py @@ -0,0 +1,50 @@ +"""Load HashPrepConfig from YAML, TOML, or JSON files.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from ..config import HashPrepConfig, config_from_dict + + +def load_config(path: str | Path) -> HashPrepConfig: + """Load a HashPrepConfig from a YAML (.yaml/.yml), TOML (.toml), or JSON (.json) file. + + Only keys present in the file are overridden; all others fall back to defaults. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {path}") + + suffix = path.suffix.lower() + + if suffix in (".yaml", ".yml"): + try: + import yaml + except ImportError as e: + raise ImportError("pyyaml is required for YAML config files: pip install pyyaml") from e + with open(path) as f: + raw = yaml.safe_load(f) or {} + elif suffix == ".toml": + try: + import tomllib + except ImportError: + try: + import tomli as tomllib # type: ignore[no-redef] + except ImportError as e: + raise ImportError( + "tomllib (Python 3.11+) or tomli is required for TOML config files: pip install tomli" + ) from e + with open(path, "rb") as f: + raw = tomllib.load(f) + elif suffix == ".json": + with open(path) as f: + raw = json.load(f) + else: + raise ValueError(f"Unsupported config file format: {suffix!r}. Use .yaml, .yml, .toml, or .json") + + if not isinstance(raw, dict): + raise ValueError(f"Config file must contain a mapping at the top level, got {type(raw).__name__}") + + return config_from_dict(raw) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py new file mode 100644 index 0000000..86eecd1 --- /dev/null +++ b/tests/test_config_loader.py @@ -0,0 +1,217 @@ +"""Tests for config file loading and config_from_dict.""" + +import json +import textwrap + +import numpy as np +import pandas as pd +import pytest + +from hashprep import DatasetAnalyzer, HashPrepConfig, load_config +from hashprep.config import DEFAULT_CONFIG, config_from_dict + +# --------------------------------------------------------------------------- +# config_from_dict +# --------------------------------------------------------------------------- + + +class TestConfigFromDict: + def test_empty_dict_returns_defaults(self): + cfg = config_from_dict({}) + assert cfg == DEFAULT_CONFIG + + def test_partial_override_missing_values(self): + cfg = config_from_dict({"missing_values": {"warning": 0.1}}) + assert cfg.missing_values.warning == 0.1 + # Other fields stay at default + assert cfg.missing_values.critical == DEFAULT_CONFIG.missing_values.critical + + def test_partial_override_outliers(self): + cfg = config_from_dict({"outliers": {"z_score": 3.0}}) + assert cfg.outliers.z_score == 3.0 + assert cfg.outliers.outlier_ratio_critical == DEFAULT_CONFIG.outliers.outlier_ratio_critical + + def test_multiple_section_override(self): + cfg = config_from_dict( + { + "missing_values": {"warning": 0.2}, + "outliers": {"z_score": 2.5}, + } + ) + assert cfg.missing_values.warning == 0.2 + assert cfg.outliers.z_score == 2.5 + # Unmodified sections stay at default + assert cfg.correlations == DEFAULT_CONFIG.correlations + + def test_unknown_keys_are_ignored(self): + # Should not raise + cfg = config_from_dict({"nonexistent_section": {"foo": 1}}) + assert cfg == DEFAULT_CONFIG + + def test_unknown_nested_keys_are_ignored(self): + cfg = config_from_dict({"outliers": {"z_score": 3.0, "nonexistent": 99}}) + assert cfg.outliers.z_score == 3.0 + + def test_returns_hashprepconfig_instance(self): + cfg = config_from_dict({}) + assert isinstance(cfg, HashPrepConfig) + + def test_result_is_frozen(self): + cfg = config_from_dict({"outliers": {"z_score": 2.0}}) + with pytest.raises((AttributeError, TypeError)): + cfg.outliers = None # type: ignore[assignment] + + def test_int_override(self): + cfg = config_from_dict({"statistical_tests": {"shapiro_max_n": 1000}}) + assert cfg.statistical_tests.shapiro_max_n == 1000 + + def test_float_override(self): + cfg = config_from_dict({"correlations": {"spearman_warning": 0.8}}) + assert cfg.correlations.spearman_warning == 0.8 + + +# --------------------------------------------------------------------------- +# load_config — YAML +# --------------------------------------------------------------------------- + + +class TestLoadConfigYaml: + def test_load_minimal_yaml(self, tmp_path): + yaml_content = textwrap.dedent("""\ + missing_values: + warning: 0.3 + """) + cfg_file = tmp_path / "config.yaml" + cfg_file.write_text(yaml_content) + + cfg = load_config(cfg_file) + assert cfg.missing_values.warning == 0.3 + + def test_load_yml_extension(self, tmp_path): + cfg_file = tmp_path / "config.yml" + cfg_file.write_text("outliers:\n z_score: 3.5\n") + cfg = load_config(cfg_file) + assert cfg.outliers.z_score == 3.5 + + def test_empty_yaml_returns_defaults(self, tmp_path): + cfg_file = tmp_path / "config.yaml" + cfg_file.write_text("") + cfg = load_config(cfg_file) + assert cfg == DEFAULT_CONFIG + + def test_load_multi_section_yaml(self, tmp_path): + yaml_content = textwrap.dedent("""\ + outliers: + z_score: 3.0 + skewness_warning: 2.0 + correlations: + spearman_warning: 0.8 + """) + cfg_file = tmp_path / "config.yaml" + cfg_file.write_text(yaml_content) + cfg = load_config(cfg_file) + assert cfg.outliers.z_score == 3.0 + assert cfg.outliers.skewness_warning == 2.0 + assert cfg.correlations.spearman_warning == 0.8 + + +# --------------------------------------------------------------------------- +# load_config — JSON +# --------------------------------------------------------------------------- + + +class TestLoadConfigJson: + def test_load_json(self, tmp_path): + data = {"missing_values": {"warning": 0.25, "critical": 0.6}} + cfg_file = tmp_path / "config.json" + cfg_file.write_text(json.dumps(data)) + cfg = load_config(cfg_file) + assert cfg.missing_values.warning == 0.25 + assert cfg.missing_values.critical == 0.6 + + def test_empty_json_object_returns_defaults(self, tmp_path): + cfg_file = tmp_path / "config.json" + cfg_file.write_text("{}") + cfg = load_config(cfg_file) + assert cfg == DEFAULT_CONFIG + + +# --------------------------------------------------------------------------- +# load_config — error cases +# --------------------------------------------------------------------------- + + +class TestLoadConfigErrors: + def test_file_not_found(self): + with pytest.raises(FileNotFoundError): + load_config("/nonexistent/path/config.yaml") + + def test_unsupported_extension(self, tmp_path): + cfg_file = tmp_path / "config.ini" + cfg_file.write_text("[section]\nkey = value\n") + with pytest.raises(ValueError, match="Unsupported config file format"): + load_config(cfg_file) + + +# --------------------------------------------------------------------------- +# DatasetAnalyzer — custom config respected by checks +# --------------------------------------------------------------------------- + + +rng = np.random.default_rng(0) + + +class TestAnalyzerCustomConfig: + def test_default_config_is_set(self): + df = pd.DataFrame({"x": rng.standard_normal(100)}) + analyzer = DatasetAnalyzer(df, auto_sample=False) + assert analyzer.config == DEFAULT_CONFIG + + def test_custom_config_stored(self): + df = pd.DataFrame({"x": rng.standard_normal(100)}) + custom = config_from_dict({"outliers": {"z_score": 2.0}}) + analyzer = DatasetAnalyzer(df, auto_sample=False, config=custom) + assert analyzer.config.outliers.z_score == 2.0 + + def test_high_missing_threshold_suppresses_issue(self): + # Column has 50% missing — normally a warning, but with warning=0.9 it should be silent + data = [1.0] * 50 + [float("nan")] * 50 + df = pd.DataFrame({"x": data}) + custom = config_from_dict({"missing_values": {"warning": 0.9, "critical": 0.95}}) + analyzer = DatasetAnalyzer(df, auto_sample=False, config=custom, selected_checks=["high_missing_values"]) + summary = analyzer.analyze() + categories = [i["category"] for i in summary["issues"]] + assert "missing_values" not in categories + + def test_low_missing_threshold_triggers_issue(self): + # Column has 10% missing — default threshold is 0.4 (warning), but with 0.05 it should fire + data = [1.0] * 90 + [float("nan")] * 10 + df = pd.DataFrame({"x": data}) + custom = config_from_dict({"missing_values": {"warning": 0.05, "critical": 0.5}}) + analyzer = DatasetAnalyzer(df, auto_sample=False, config=custom, selected_checks=["high_missing_values"]) + summary = analyzer.analyze() + categories = [i["category"] for i in summary["issues"]] + assert "missing_values" in categories + + def test_skewness_threshold_respected(self): + # Highly skewed data — lower the threshold to catch it with warning=0.5 + data = [1.0] * 90 + [100.0] * 10 + df = pd.DataFrame({"x": data}) + custom = config_from_dict({"outliers": {"skewness_warning": 0.5, "skewness_critical": 20.0}}) + analyzer = DatasetAnalyzer(df, auto_sample=False, config=custom, selected_checks=["skewness"]) + summary = analyzer.analyze() + categories = [i["category"] for i in summary["issues"]] + assert "skewness" in categories + + def test_load_config_from_yaml_used_in_analyzer(self, tmp_path): + yaml_content = "missing_values:\n warning: 0.05\n critical: 0.5\n" + cfg_file = tmp_path / "custom.yaml" + cfg_file.write_text(yaml_content) + + data = [1.0] * 90 + [float("nan")] * 10 + df = pd.DataFrame({"x": data}) + cfg = load_config(cfg_file) + analyzer = DatasetAnalyzer(df, auto_sample=False, config=cfg, selected_checks=["high_missing_values"]) + summary = analyzer.analyze() + categories = [i["category"] for i in summary["issues"]] + assert "missing_values" in categories diff --git a/tests/test_datetime.py b/tests/test_datetime.py index 67c6c73..e282b07 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -67,8 +67,11 @@ class _FakeAnalyzer: """Minimal stand-in for DatasetAnalyzer used in unit tests.""" def __init__(self, df, column_types): + from hashprep.config import DEFAULT_CONFIG + self.df = df self.column_types = column_types + self.config = DEFAULT_CONFIG class TestFutureDatesCheck: diff --git a/tests/test_mutual_info.py b/tests/test_mutual_info.py index 0aa0e1c..e864d61 100644 --- a/tests/test_mutual_info.py +++ b/tests/test_mutual_info.py @@ -20,9 +20,12 @@ class _FakeAnalyzer: def __init__(self, df, target_col=None): + from hashprep.config import DEFAULT_CONFIG + self.df = df self.target_col = target_col self.column_types = infer_types(df) + self.config = DEFAULT_CONFIG # --------------------------------------------------------------------------- diff --git a/tests/test_statistical_tests.py b/tests/test_statistical_tests.py index 8ae64d5..635c238 100644 --- a/tests/test_statistical_tests.py +++ b/tests/test_statistical_tests.py @@ -15,11 +15,13 @@ class _FakeAnalyzer: def __init__(self, df, target_col=None): - self.df = df - self.target_col = target_col + from hashprep.config import DEFAULT_CONFIG from hashprep.utils.type_inference import infer_types + self.df = df + self.target_col = target_col self.column_types = infer_types(df) + self.config = DEFAULT_CONFIG rng = np.random.default_rng(42)