Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/qrf-convenience-features.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `max_train_samples` parameter and `fit_predict()` method to QRF, with automatic zero-filling of missing output variables.
75 changes: 75 additions & 0 deletions microimpute/models/qrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def __init__(
memory_efficient: bool = False,
batch_size: Optional[int] = None,
cleanup_interval: int = 10,
max_train_samples: Optional[int] = None,
) -> None:
"""Initialize the QRF model.

Expand All @@ -500,13 +501,19 @@ def __init__(
memory_efficient: Enable memory optimization features.
batch_size: Process variables in batches to reduce memory usage.
cleanup_interval: Frequency of garbage collection (every N variables).
max_train_samples: If set, subsample X_train to at most this many
rows before fitting. Reduces memory and training time while
preserving sequential covariance structure.
"""
super().__init__(log_level=log_level)
self.models = {}
self.log_level = log_level
self.memory_efficient = memory_efficient
self.batch_size = batch_size
self.cleanup_interval = cleanup_interval
if max_train_samples is not None and max_train_samples < 1:
raise ValueError("max_train_samples must be a positive integer")
self.max_train_samples = max_train_samples

self.logger.debug("Initializing QRF imputer")

Expand Down Expand Up @@ -675,6 +682,20 @@ def _fit(
RuntimeError: If model fitting fails.
"""
try:
# Subsample training data if max_train_samples is set
if (
self.max_train_samples is not None
and len(X_train) > self.max_train_samples
):
self.logger.info(
f"Subsampling training data from "
f"{len(X_train)} to {self.max_train_samples} rows"
)
X_train = X_train.sample(
n=self.max_train_samples,
random_state=self.seed,
).reset_index(drop=True)

# Store target type information early for hyperparameter tuning
self.categorical_targets = categorical_targets or {}
self.boolean_targets = boolean_targets or {}
Expand Down Expand Up @@ -1081,6 +1102,60 @@ def _fit_variable_batch(
f" Memory cleanup performed. Usage: {self._get_memory_usage_info()}"
)

def fit_predict(
    self,
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    predictors: List[str],
    imputed_variables: List[str],
    **kwargs: Any,
) -> pd.DataFrame:
    """Fit, predict, and discard the fitted model in a single call.

    Convenience wrapper combining fit() + predict() + cleanup, for
    callers that have no further use for the fitted model.

    Any entry of ``imputed_variables`` that is absent from ``X_train``
    is skipped while fitting and emitted as an all-zero column, so the
    caller does not need to pre-filter the variable list.

    Args:
        X_train: DataFrame containing the training data.
        X_test: DataFrame containing the test data (predictors only).
        predictors: List of column names to use as predictors.
        imputed_variables: List of column names to impute.
        **kwargs: Additional keyword arguments passed to fit().

    Returns:
        DataFrame with one column per imputed variable.
    """
    missing = [
        var for var in imputed_variables if var not in X_train.columns
    ]
    if missing:
        self.logger.warning(
            f"fit_predict: {len(missing)} variables not in X_train "
            f"and will be zero-filled: {missing}"
        )

    fitted_model = self.fit(
        X_train=X_train,
        predictors=predictors,
        imputed_variables=imputed_variables,
        skip_missing=True,
        **kwargs,
    )
    result = fitted_model.predict(X_test=X_test[predictors])

    # Drop the local reference and collect immediately to release the
    # fitted forest's memory before post-processing.
    del fitted_model
    gc.collect()

    # Zero-fill missing variables to match the requested output shape.
    for var in missing:
        result[var] = 0

    # Reorder columns to match the original requested order.
    ordered_cols = [v for v in imputed_variables if v in result.columns]
    return result[ordered_cols]

def _tune_qrf_hyperparameters(
self,
data: pd.DataFrame,
Expand Down
180 changes: 180 additions & 0 deletions tests/test_models/test_qrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,3 +1257,183 @@ def test_qrf_hyperparameter_tuning_improves_performance() -> None:
assert tuned_loss <= untuned_loss * (1 + margin), (
f"Tuned loss ({tuned_loss:.4f}) should be ≤ {(1 + margin) * 100}% of untuned loss ({untuned_loss:.4f} * {1 + margin} = {untuned_loss * (1 + margin):.4f})"
)


# === max_train_samples Tests ===


def test_qrf_max_train_samples_subsamples() -> None:
    """Test that max_train_samples reduces training data size.

    Attaches a capturing handler to the model's logger and asserts
    that the subsampling INFO message is emitted, then checks that
    prediction still works on the subsampled model.
    """
    np.random.seed(42)
    n_samples = 500

    data = pd.DataFrame(
        {
            "x1": np.random.randn(n_samples),
            "x2": np.random.randn(n_samples),
            "y": np.random.randn(n_samples),
        }
    )

    # Capture INFO-level log output to assert on the subsampling message.
    log_stream = io.StringIO()
    handler = logging.StreamHandler(log_stream)
    handler.setLevel(logging.INFO)

    model = QRF(log_level="INFO", max_train_samples=100)
    model.logger.addHandler(handler)
    try:
        fitted = model.fit(
            data,
            predictors=["x1", "x2"],
            imputed_variables=["y"],
            n_estimators=10,
        )

        log_output = log_stream.getvalue()
        assert "Subsampling training data from 500 to 100 rows" in log_output

        # Predictions should still work after subsampling.
        test_data = data[["x1", "x2"]].head(10)
        predictions = fitted.predict(test_data)
        assert isinstance(predictions, pd.DataFrame)
        assert len(predictions) == 10
    finally:
        # Always detach the handler — without try/finally a failed
        # assertion would leak it onto the shared logger and pollute
        # log capture in subsequent tests.
        model.logger.removeHandler(handler)


def test_qrf_max_train_samples_no_op_when_small() -> None:
    """max_train_samples should be a no-op when data is already small.

    With fewer rows than the cap, no subsampling message should be
    logged at all.
    """
    np.random.seed(42)
    n_samples = 50

    data = pd.DataFrame(
        {
            "x1": np.random.randn(n_samples),
            "y": np.random.randn(n_samples),
        }
    )

    # Capture INFO-level log output to verify no subsampling occurred.
    log_stream = io.StringIO()
    handler = logging.StreamHandler(log_stream)
    handler.setLevel(logging.INFO)

    model = QRF(log_level="INFO", max_train_samples=100)
    model.logger.addHandler(handler)
    try:
        model.fit(
            data,
            predictors=["x1"],
            imputed_variables=["y"],
            n_estimators=10,
        )

        log_output = log_stream.getvalue()
        assert "Subsampling" not in log_output
    finally:
        # Always detach the handler — without try/finally a failed
        # assertion would leak it onto the shared logger and pollute
        # log capture in subsequent tests.
        model.logger.removeHandler(handler)


# === fit_predict Tests ===


def test_qrf_fit_predict_basic() -> None:
    """fit_predict should return the same shape as fit + predict."""
    np.random.seed(42)
    n_train, n_test = 200, 50

    train_df = pd.DataFrame(
        {
            "x": np.random.randn(n_train),
            "y1": np.random.randn(n_train),
            "y2": np.random.randn(n_train),
        }
    )
    test_df = pd.DataFrame({"x": np.random.randn(n_test)})

    qrf = QRF(log_level="WARNING")
    result = qrf.fit_predict(
        X_train=train_df,
        X_test=test_df,
        predictors=["x"],
        imputed_variables=["y1", "y2"],
        n_estimators=10,
    )

    # One row per test observation, one column per imputed variable,
    # with no missing values anywhere.
    assert isinstance(result, pd.DataFrame)
    assert result.shape == (n_test, 2)
    assert list(result.columns) == ["y1", "y2"]
    assert not result.isna().any().any()


def test_qrf_fit_predict_missing_variables() -> None:
    """fit_predict should zero-fill variables missing from X_train.

    Variables requested but absent from the training frame must be
    warned about, appear as all-zero columns, and keep the requested
    column order in the output.
    """
    np.random.seed(42)
    n_train, n_test = 200, 50

    train = pd.DataFrame(
        {
            "x": np.random.randn(n_train),
            "y_present": np.random.randn(n_train),
        }
    )
    test = pd.DataFrame({"x": np.random.randn(n_test)})

    # Capture WARNING-level output to assert on the zero-fill message.
    log_stream = io.StringIO()
    handler = logging.StreamHandler(log_stream)
    handler.setLevel(logging.WARNING)

    model = QRF(log_level="WARNING")
    model.logger.addHandler(handler)
    try:
        result = model.fit_predict(
            X_train=train,
            X_test=test,
            predictors=["x"],
            imputed_variables=["y_present", "y_missing1", "y_missing2"],
            n_estimators=10,
        )

        # The warning should name every zero-filled variable.
        log_output = log_stream.getvalue()
        assert "y_missing1" in log_output
        assert "y_missing2" in log_output
        assert "zero-filled" in log_output

        # Output should have all three columns, in the requested order.
        assert list(result.columns) == [
            "y_present",
            "y_missing1",
            "y_missing2",
        ]
        # Present variable should carry real (non-zero) predictions.
        assert result["y_present"].abs().sum() > 0
        # Missing variables should be zero-filled.
        assert (result["y_missing1"] == 0).all()
        assert (result["y_missing2"] == 0).all()
    finally:
        # Always detach the handler — without try/finally a failed
        # assertion would leak it onto the shared logger and pollute
        # log capture in subsequent tests.
        model.logger.removeHandler(handler)


def test_qrf_fit_predict_with_max_train_samples() -> None:
    """fit_predict should work together with max_train_samples."""
    np.random.seed(42)
    n_train, n_test = 500, 30

    train_df = pd.DataFrame(
        {
            "x": np.random.randn(n_train),
            "y": np.random.randn(n_train),
        }
    )
    test_df = pd.DataFrame({"x": np.random.randn(n_test)})

    qrf = QRF(log_level="WARNING", max_train_samples=100)
    result = qrf.fit_predict(
        X_train=train_df,
        X_test=test_df,
        predictors=["x"],
        imputed_variables=["y"],
        n_estimators=10,
    )

    # Subsampled training must still yield one complete prediction
    # per test row.
    assert result.shape == (n_test, 1)
    assert not result.isna().any().any()