diff --git a/changelog.d/qrf-convenience-features.added.md b/changelog.d/qrf-convenience-features.added.md new file mode 100644 index 0000000..99d8a1f --- /dev/null +++ b/changelog.d/qrf-convenience-features.added.md @@ -0,0 +1 @@ +Added `max_train_samples` parameter and `fit_predict()` method to QRF, with automatic zero-filling of missing output variables. diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py index b1bbecf..b7617e3 100644 --- a/microimpute/models/qrf.py +++ b/microimpute/models/qrf.py @@ -492,6 +492,7 @@ def __init__( memory_efficient: bool = False, batch_size: Optional[int] = None, cleanup_interval: int = 10, + max_train_samples: Optional[int] = None, ) -> None: """Initialize the QRF model. @@ -500,6 +501,9 @@ def __init__( memory_efficient: Enable memory optimization features. batch_size: Process variables in batches to reduce memory usage. cleanup_interval: Frequency of garbage collection (every N variables). + max_train_samples: If set, randomly subsample X_train to at most this many + rows before fitting. Reduces memory and training time while + approximately preserving the joint distribution of the data. """ super().__init__(log_level=log_level) self.models = {} @@ -507,6 +511,9 @@ def __init__( self.memory_efficient = memory_efficient self.batch_size = batch_size self.cleanup_interval = cleanup_interval + if max_train_samples is not None and max_train_samples < 1: + raise ValueError("max_train_samples must be a positive integer") + self.max_train_samples = max_train_samples self.logger.debug("Initializing QRF imputer") @@ -675,6 +682,20 @@ def _fit( RuntimeError: If model fitting fails. 
""" try: + # Subsample training data if max_train_samples is set + if ( + self.max_train_samples is not None + and len(X_train) > self.max_train_samples + ): + self.logger.info( + f"Subsampling training data from " + f"{len(X_train)} to {self.max_train_samples} rows" + ) + X_train = X_train.sample( + n=self.max_train_samples, + random_state=self.seed, + ).reset_index(drop=True) + # Store target type information early for hyperparameter tuning self.categorical_targets = categorical_targets or {} self.boolean_targets = boolean_targets or {} @@ -1081,6 +1102,60 @@ def _fit_variable_batch( f" Memory cleanup performed. Usage: {self._get_memory_usage_info()}" ) + def fit_predict( + self, + X_train: pd.DataFrame, + X_test: pd.DataFrame, + predictors: List[str], + imputed_variables: List[str], + **kwargs: Any, + ) -> pd.DataFrame: + """Fit the model and immediately predict, then release the fitted model. + + Convenience method that combines fit() + predict() + cleanup. + Useful when you don't need to keep the fitted model around. + + Variables in ``imputed_variables`` that are missing from ``X_train`` + are automatically skipped during fitting and zero-filled in the + output, so callers don't need to pre-filter. + + Args: + X_train: DataFrame containing the training data. + X_test: DataFrame containing the test data (predictors only). + predictors: List of column names to use as predictors. + imputed_variables: List of column names to impute. + **kwargs: Additional keyword arguments passed to fit(). + + Returns: + DataFrame with one column per imputed variable. 
+ """ + missing = [v for v in imputed_variables if v not in X_train.columns] + if missing: + self.logger.warning( + f"fit_predict: {len(missing)} variables not in X_train " + f"and will be zero-filled: {missing}" + ) + + fitted = self.fit( + X_train=X_train, + predictors=predictors, + imputed_variables=imputed_variables, + skip_missing=True, + **kwargs, + ) + + result = fitted.predict(X_test=X_test[predictors]) + del fitted + gc.collect() + + # Zero-fill missing variables to match the requested output shape. + for var in missing: + result[var] = 0 + + # Reorder columns to match the original requested order. + result = result[[v for v in imputed_variables if v in result.columns]] + return result + def _tune_qrf_hyperparameters( self, data: pd.DataFrame, diff --git a/tests/test_models/test_qrf.py b/tests/test_models/test_qrf.py index 695f8ad..c97868a 100644 --- a/tests/test_models/test_qrf.py +++ b/tests/test_models/test_qrf.py @@ -1257,3 +1257,183 @@ def test_qrf_hyperparameter_tuning_improves_performance() -> None: assert tuned_loss <= untuned_loss * (1 + margin), ( f"Tuned loss ({tuned_loss:.4f}) should be ≤ {(1 + margin) * 100}% of untuned loss ({untuned_loss:.4f} * {1 + margin} = {untuned_loss * (1 + margin):.4f})" ) + + +# === max_train_samples Tests === + + +def test_qrf_max_train_samples_subsamples() -> None: + """Test that max_train_samples reduces training data size.""" + np.random.seed(42) + n_samples = 500 + + data = pd.DataFrame( + { + "x1": np.random.randn(n_samples), + "x2": np.random.randn(n_samples), + "y": np.random.randn(n_samples), + } + ) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.INFO) + + model = QRF(log_level="INFO", max_train_samples=100) + model.logger.addHandler(handler) + + fitted = model.fit( + data, + predictors=["x1", "x2"], + imputed_variables=["y"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "Subsampling training data from 500 to 100 rows" in 
log_output + + # Predictions should still work + test_data = data[["x1", "x2"]].head(10) + predictions = fitted.predict(test_data) + assert isinstance(predictions, pd.DataFrame) + assert len(predictions) == 10 + + model.logger.removeHandler(handler) + + +def test_qrf_max_train_samples_no_op_when_small() -> None: + """max_train_samples should be a no-op when data is already small.""" + np.random.seed(42) + n_samples = 50 + + data = pd.DataFrame( + { + "x1": np.random.randn(n_samples), + "y": np.random.randn(n_samples), + } + ) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.INFO) + + model = QRF(log_level="INFO", max_train_samples=100) + model.logger.addHandler(handler) + + model.fit( + data, + predictors=["x1"], + imputed_variables=["y"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "Subsampling" not in log_output + + model.logger.removeHandler(handler) + + +# === fit_predict Tests === + + +def test_qrf_fit_predict_basic() -> None: + """fit_predict should return the same shape as fit + predict.""" + np.random.seed(42) + n_train, n_test = 200, 50 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y1": np.random.randn(n_train), + "y2": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": np.random.randn(n_test)}) + + model = QRF(log_level="WARNING") + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y1", "y2"], + n_estimators=10, + ) + + assert isinstance(result, pd.DataFrame) + assert result.shape == (n_test, 2) + assert list(result.columns) == ["y1", "y2"] + assert not result.isna().any().any() + + +def test_qrf_fit_predict_missing_variables() -> None: + """fit_predict should zero-fill variables missing from X_train.""" + np.random.seed(42) + n_train, n_test = 200, 50 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y_present": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": 
np.random.randn(n_test)}) + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setLevel(logging.WARNING) + + model = QRF(log_level="WARNING") + model.logger.addHandler(handler) + + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y_present", "y_missing1", "y_missing2"], + n_estimators=10, + ) + + log_output = log_stream.getvalue() + assert "y_missing1" in log_output + assert "y_missing2" in log_output + assert "zero-filled" in log_output + + # Output should have all three columns + assert list(result.columns) == [ + "y_present", + "y_missing1", + "y_missing2", + ] + # Present variable should have non-zero values + assert result["y_present"].abs().sum() > 0 + # Missing variables should be zero + assert (result["y_missing1"] == 0).all() + assert (result["y_missing2"] == 0).all() + + model.logger.removeHandler(handler) + + +def test_qrf_fit_predict_with_max_train_samples() -> None: + """fit_predict should work together with max_train_samples.""" + np.random.seed(42) + n_train, n_test = 500, 30 + + train = pd.DataFrame( + { + "x": np.random.randn(n_train), + "y": np.random.randn(n_train), + } + ) + test = pd.DataFrame({"x": np.random.randn(n_test)}) + + model = QRF(log_level="WARNING", max_train_samples=100) + result = model.fit_predict( + X_train=train, + X_test=test, + predictors=["x"], + imputed_variables=["y"], + n_estimators=10, + ) + + assert result.shape == (n_test, 1) + assert not result.isna().any().any()