From e08231fb806c5c5104871f9df5aff75f28a36743 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 9 Mar 2026 11:23:29 -0400 Subject: [PATCH 1/2] Add max_train_samples, fit_predict, and missing variable handling to QRF Adds three convenience features that downstream consumers (policyengine-us-data, policyengine-uk-data) currently implement manually: 1. max_train_samples: auto-subsample training data to reduce memory while preserving sequential covariance (the correct fix for #96) 2. fit_predict(): combines fit + predict + gc cleanup in one call 3. fit_predict() zero-fills variables missing from X_train instead of erroring Fixes #169 Co-Authored-By: Claude Opus 4.6 --- changelog.d/qrf-convenience-features.added.md | 1 + microimpute/models/qrf.py | 75 ++++++++ tests/test_models/test_qrf.py | 180 ++++++++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 changelog.d/qrf-convenience-features.added.md diff --git a/changelog.d/qrf-convenience-features.added.md b/changelog.d/qrf-convenience-features.added.md new file mode 100644 index 0000000..99d8a1f --- /dev/null +++ b/changelog.d/qrf-convenience-features.added.md @@ -0,0 +1 @@ +Added `max_train_samples` parameter and `fit_predict()` method to QRF, with automatic zero-filling of missing output variables. diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py index b1bbecf..5bd7a78 100644 --- a/microimpute/models/qrf.py +++ b/microimpute/models/qrf.py @@ -492,6 +492,7 @@ def __init__( memory_efficient: bool = False, batch_size: Optional[int] = None, cleanup_interval: int = 10, + max_train_samples: Optional[int] = None, ) -> None: """Initialize the QRF model. @@ -500,6 +501,9 @@ def __init__( memory_efficient: Enable memory optimization features. batch_size: Process variables in batches to reduce memory usage. cleanup_interval: Frequency of garbage collection (every N variables). + max_train_samples: If set, subsample X_train to at most this many + rows before fitting. Reduces memory and training time while + preserving sequential covariance structure. """ super().__init__(log_level=log_level) self.models = {} @@ -507,6 +511,7 @@ def __init__( self.memory_efficient = memory_efficient self.batch_size = batch_size self.cleanup_interval = cleanup_interval + self.max_train_samples = max_train_samples self.logger.debug("Initializing QRF imputer") @@ -675,6 +680,20 @@ def _fit( RuntimeError: If model fitting fails. """ try: + # Subsample training data if max_train_samples is set + if ( + self.max_train_samples is not None + and len(X_train) > self.max_train_samples + ): + self.logger.info( + f"Subsampling training data from " + f"{len(X_train)} to {self.max_train_samples} rows" + ) + X_train = X_train.sample( + n=self.max_train_samples, + random_state=self.seed, + ) + # Store target type information early for hyperparameter tuning self.categorical_targets = categorical_targets or {} self.boolean_targets = boolean_targets or {} @@ -1081,6 +1100,62 @@ def _fit_variable_batch( f" Memory cleanup performed. Usage: {self._get_memory_usage_info()}" ) + def fit_predict( + self, + X_train: pd.DataFrame, + X_test: pd.DataFrame, + predictors: List[str], + imputed_variables: List[str], + **kwargs: Any, + ) -> pd.DataFrame: + """Fit the model and immediately predict, then release the fitted model. + + Convenience method that combines fit() + predict() + cleanup. + Useful when you don't need to keep the fitted model around. + + Variables in ``imputed_variables`` that are missing from ``X_train`` + are automatically skipped during fitting and zero-filled in the + output, so callers don't need to pre-filter. + + Args: + X_train: DataFrame containing the training data. + X_test: DataFrame containing the test data (predictors only). + predictors: List of column names to use as predictors. + imputed_variables: List of column names to impute. + **kwargs: Additional keyword arguments passed to fit(). + + Returns: + DataFrame with one column per imputed variable. + """ + # Identify missing variables before fit (which would error). + available = [v for v in imputed_variables if v in X_train.columns] + missing = [v for v in imputed_variables if v not in X_train.columns] + if missing: + self.logger.warning( + f"fit_predict: {len(missing)} variables not in X_train " + f"and will be zero-filled: {missing}" + ) + + fitted = self.fit( + X_train=X_train, + predictors=predictors, + imputed_variables=available, + **kwargs, + ) + + test_predictors = [p for p in predictors if p in X_test.columns] + result = fitted.predict(X_test=X_test[test_predictors]) + del fitted + gc.collect() + + # Zero-fill missing variables to match the requested output shape. + for var in missing: + result[var] = 0 + + # Reorder columns to match the original requested order. + result = result[[v for v in imputed_variables if v in result.columns]] + return result + def _tune_qrf_hyperparameters( self, data: pd.DataFrame, diff --git a/tests/test_models/test_qrf.py b/tests/test_models/test_qrf.py index 695f8ad..c97868a 100644 --- a/tests/test_models/test_qrf.py +++ b/tests/test_models/test_qrf.py @@ -1257,3 +1257,183 @@ def test_qrf_hyperparameter_tuning_improves_performance() -> None: assert tuned_loss <= untuned_loss * (1 + margin), ( f"Tuned loss ({tuned_loss:.4f}) should be ≤ {(1 + margin) * 100}% of untuned loss ({untuned_loss:.4f} * {1 + margin} = {untuned_loss * (1 + margin):.4f})" ) + + +# === max_train_samples Tests === + + +def test_qrf_max_train_samples_subsamples() -> None: + """Test that max_train_samples reduces training data size.""" + np.random.seed(42) + n_samples = 500 + + data = pd.DataFrame( + { + "x1": np.random.randn(n_samples), + "x2": np.random.randn(n_samples), + "y": np.random.randn(n_samples), + } + ) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.INFO) + + model = QRF(log_level="INFO", max_train_samples=100) + model.logger.addHandler(handler) + + fitted = model.fit( + data, + predictors=["x1", "x2"], + imputed_variables=["y"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "Subsampling training data from 500 to 100 rows" in log_output + + # Predictions should still work + test_data = data[["x1", "x2"]].head(10) + predictions = fitted.predict(test_data) + assert isinstance(predictions, pd.DataFrame) + assert len(predictions) == 10 + + model.logger.removeHandler(handler) + + +def test_qrf_max_train_samples_no_op_when_small() -> None: + """max_train_samples should be a no-op when data is already small.""" + np.random.seed(42) + n_samples = 50 + + data = pd.DataFrame( + { + "x1": np.random.randn(n_samples), + "y": np.random.randn(n_samples), + } + ) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.INFO) + + model = QRF(log_level="INFO", max_train_samples=100) + model.logger.addHandler(handler) + + model.fit( + data, + predictors=["x1"], + imputed_variables=["y"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "Subsampling" not in log_output + + model.logger.removeHandler(handler) + + +# === fit_predict Tests === + + +def test_qrf_fit_predict_basic() -> None: + """fit_predict should return the same shape as fit + predict.""" + np.random.seed(42) + n_train, n_test = 200, 50 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y1": np.random.randn(n_train), + "y2": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": np.random.randn(n_test)}) + + model = QRF(log_level="WARNING") + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y1", "y2"], + n_estimators=10, + ) + + assert isinstance(result, pd.DataFrame) + assert result.shape == (n_test, 2) + assert list(result.columns) == ["y1", "y2"] + assert not result.isna().any().any() + + +def test_qrf_fit_predict_missing_variables() -> None: + """fit_predict should zero-fill variables missing from X_train.""" + np.random.seed(42) + n_train, n_test = 200, 50 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y_present": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": np.random.randn(n_test)}) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.WARNING) + + model = QRF(log_level="WARNING") + model.logger.addHandler(handler) + + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y_present", "y_missing1", "y_missing2"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "y_missing1" in log_output + assert "y_missing2" in log_output + assert "zero-filled" in log_output + + # Output should have all three columns + assert list(result.columns) == [ + "y_present", + "y_missing1", + "y_missing2", + ] + # Present variable should have non-zero values + assert result["y_present"].abs().sum() > 0 + # Missing variables should be zero + assert (result["y_missing1"] == 0).all() + assert (result["y_missing2"] == 0).all() + + model.logger.removeHandler(handler) + + +def test_qrf_fit_predict_with_max_train_samples() -> None: + """fit_predict should work together with max_train_samples.""" + np.random.seed(42) + n_train, n_test = 500, 30 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": np.random.randn(n_test)}) + + model = QRF(log_level="WARNING", max_train_samples=100) + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y"], + n_estimators=10, + ) + + assert result.shape == (n_test, 1) + assert not result.isna().any().any() From fcf1640be89a5dc573f179ef90aefa3bc5367001 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 9 Mar 2026 11:27:00 -0400 Subject: [PATCH 2/2] Fix index reset, use skip_missing, validate max_train_samples - Add .reset_index(drop=True) after subsampling to prevent index corruption during sequential imputation - Use skip_missing=True in fit_predict() instead of reimplementing _handle_missing_variables() logic - Validate max_train_samples is positive Co-Authored-By: Claude Opus 4.6 --- microimpute/models/qrf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py index 5bd7a78..b7617e3 100644 --- a/microimpute/models/qrf.py +++ b/microimpute/models/qrf.py @@ -511,6 +511,8 @@ def __init__( self.memory_efficient = memory_efficient self.batch_size = batch_size self.cleanup_interval = cleanup_interval + if max_train_samples is not None and max_train_samples < 1: + raise ValueError("max_train_samples must be a positive integer") self.max_train_samples = max_train_samples self.logger.debug("Initializing QRF imputer") @@ -692,7 +694,7 @@ def _fit( X_train = X_train.sample( n=self.max_train_samples, random_state=self.seed, - ) + ).reset_index(drop=True) # Store target type information early for hyperparameter tuning self.categorical_targets = categorical_targets or {} @@ -1127,8 +1129,6 @@ def fit_predict( Returns: DataFrame with one column per imputed variable. """ - # Identify missing variables before fit (which would error). - available = [v for v in imputed_variables if v in X_train.columns] missing = [v for v in imputed_variables if v not in X_train.columns] if missing: self.logger.warning( @@ -1139,12 +1139,12 @@ def fit_predict( fitted = self.fit( X_train=X_train, predictors=predictors, - imputed_variables=available, + imputed_variables=imputed_variables, + skip_missing=True, **kwargs, ) - test_predictors = [p for p in predictors if p in X_test.columns] - result = fitted.predict(X_test=X_test[test_predictors]) + result = fitted.predict(X_test=X_test[predictors]) del fitted gc.collect()