From c2ab36ebf05e9c4e80dfd988d51c9027999c7ecd Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Mon, 26 Jan 2026 13:42:19 +0200 Subject: [PATCH 1/5] Implemented ML Pipeline Continuous new table rows RunInference --- ...m_Inference_Python_Benchmarks_Dataflow.yml | 25 +- ...rks_Dataflow_Table_Row_Inference_Batch.txt | 40 +++ ...ks_Dataflow_Table_Row_Inference_Stream.txt | 42 +++ .test-infra/tools/refresh_looker_metrics.py | 3 +- .../apache_beam/examples/inference/README.md | 66 ++++ .../examples/inference/table_row_inference.py | 337 ++++++++++++++++++ .../inference/table_row_inference_batch.py | 325 +++++++++++++++++ .../inference/table_row_inference_test.py | 190 ++++++++++ .../inference/table_row_inference_utils.py | 297 +++++++++++++++ .../table_row_inference_requirements.txt | 22 ++ .../testing/benchmarks/inference/README.md | 15 + .../table_row_inference_benchmark.py | 110 ++++++ .../load_tests/dataflow_cost_benchmark.py | 154 ++++++-- .../testing/load_tests/load_test.py | 10 +- .../apache_beam/testing/test_pipeline.py | 14 + .../www/site/content/en/performance/_index.md | 2 + .../performance/tablerowinference/_index.md | 45 +++ .../tablerowinferencestreaming/_index.md | 43 +++ website/www/site/data/performance.yaml | 32 ++ 19 files changed, 1730 insertions(+), 42 deletions(-) create mode 100644 .github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Batch.txt create mode 100644 .github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Stream.txt create mode 100644 sdks/python/apache_beam/examples/inference/table_row_inference.py create mode 100644 sdks/python/apache_beam/examples/inference/table_row_inference_batch.py create mode 100644 sdks/python/apache_beam/examples/inference/table_row_inference_test.py create mode 100644 sdks/python/apache_beam/examples/inference/table_row_inference_utils.py create mode 100644 sdks/python/apache_beam/ml/inference/table_row_inference_requirements.txt create mode 100644 sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py create mode 100644 website/www/site/content/en/performance/tablerowinference/_index.md create mode 100644 website/www/site/content/en/performance/tablerowinferencestreaming/_index.md diff --git a/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml b/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml index ff7480c320af..9348d1a9a5c5 100644 --- a/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml +++ b/.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml @@ -92,6 +92,7 @@ jobs: ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Batch_DistilBert_Base_Uncased.txt ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_VLLM_Gemma_Batch.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Base.txt # The env variables are created and populated in the test-arguments-action as "_test_arguments_" - name: get current time run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV @@ -189,4 +190,26 @@ jobs: -Prunner=DataflowRunner \ -PpythonVersion=3.10 \ 
-PloadTest.requirementsTxtFile=apache_beam/ml/inference/torch_tests_requirements.txt \ - '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_5 }} --job_name=benchmark-tests-pytorch-imagenet-python-gpu-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/torch/result_resnet152_gpu-${{env.NOW_UTC}}.txt' \ No newline at end of file + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_5 }} --job_name=benchmark-tests-pytorch-imagenet-python-gpu-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/torch/result_resnet152_gpu-${{env.NOW_UTC}}.txt' + - name: run Table Row Inference Sklearn Batch + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 180 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.inference.table_row_inference_benchmark \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + -PloadTest.requirementsTxtFile=apache_beam/ml/inference/table_row_inference_requirements.txt \ + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_9 }} --autoscaling_algorithm=NONE --metrics_table=result_table_row_inference_batch --influx_measurement=result_table_row_inference_batch --mode=batch --input_file=gs://apache-beam-ml/testing/inputs/table_rows_100k_benchmark.jsonl --input_expand_factor=100 --output_table=apache-beam-testing:beam_run_inference.result_table_row_inference_batch_outputs --job_name=benchmark-tests-table-row-inference-batch-${{env.NOW_UTC}}' + - name: run Table Row Inference Sklearn Stream + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 180 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.inference.table_row_inference_benchmark \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + -PloadTest.requirementsTxtFile=apache_beam/ml/inference/table_row_inference_requirements.txt \ + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_9 }} --autoscaling_algorithm=THROUGHPUT_BASED --max_num_workers=20 --metrics_table=result_table_row_inference_stream --influx_measurement=result_table_row_inference_stream --mode=streaming --input_subscription=projects/apache-beam-testing/subscriptions/table_row_inference_benchmark --window_size_sec=60 --trigger_interval_sec=30 --timeout_ms=900000 --output_table=apache-beam-testing:beam_run_inference.result_table_row_inference_stream_outputs --job_name=benchmark-tests-table-row-inference-stream-${{env.NOW_UTC}}' diff --git a/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Batch.txt b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Batch.txt new file mode 100644 index 000000000000..36b3527dcbb2 --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Batch.txt @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +--project=apache-beam-testing +--region=us-central1 +--worker_machine_type=n1-standard-4 +--num_workers=10 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--requirements_file=apache_beam/ml/inference/table_row_inference_requirements.txt +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=result_table_row_inference_batch +--input_options={} +--influx_measurement=result_table_row_inference_batch +--mode=batch +--input_file=gs://apache-beam-ml/testing/inputs/table_rows_100k_benchmark.jsonl +# 100k lines × 100 = 10M rows; use 1000 for 100M rows +--input_expand_factor=100 +--model_path=gs://apache-beam-ml/models/sklearn_table_classifier.pkl +--feature_columns=feature1,feature2,feature3,feature4,feature5 +--output_table=apache-beam-testing:beam_run_inference.result_table_row_inference_batch_outputs +--runner=DataflowRunner +--experiments=use_runner_v2 diff --git a/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Stream.txt b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Stream.txt new file mode 100644 index 000000000000..39ce9071840b --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Stream.txt @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +--project=apache-beam-testing +--region=us-central1 +--worker_machine_type=n1-standard-4 +--num_workers=10 +--disk_size_gb=50 +--autoscaling_algorithm=THROUGHPUT_BASED +--max_num_workers=20 +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--requirements_file=apache_beam/ml/inference/table_row_inference_requirements.txt +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=result_table_row_inference_stream +--input_options={} +--influx_measurement=result_table_row_inference_stream +--mode=streaming +--input_subscription=projects/apache-beam-testing/subscriptions/table_row_inference_benchmark +--window_size_sec=60 +--trigger_interval_sec=30 +--timeout_ms=1800000 +--model_path=gs://apache-beam-ml/models/sklearn_table_classifier.pkl +--feature_columns=feature1,feature2,feature3,feature4,feature5 +--output_table=apache-beam-testing:beam_run_inference.result_table_row_inference_stream_outputs +--runner=DataflowRunner +--experiments=use_runner_v2 diff --git a/.test-infra/tools/refresh_looker_metrics.py b/.test-infra/tools/refresh_looker_metrics.py index a4c6999be775..afd8ffa6f861 100644 --- a/.test-infra/tools/refresh_looker_metrics.py +++ b/.test-infra/tools/refresh_looker_metrics.py @@ -43,9 +43,10 @@ ("82", ["263", "264", "265", "266", "267"]), # PyTorch Sentiment Streaming DistilBERT base uncased ("85", ["268", "269", "270", "271", "272"]), # PyTorch Sentiment Batch DistilBERT base uncased ("86", ["284", "285", "286", "287", "288"]), # VLLM Batch Gemma + ("96", ["270", "304", "305", "353", "354"]), # Table Row Inference Sklearn Batch + ("106", ["355", "356", "357", "358", "359"]) # Table Row Inference Sklearn Streaming ] - def get_look(id: str) -> models.Look: look = next(iter(sdk.search_looks(id=id)), None) if not look: diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md index e0367ea69384..ebef666bec8c 100644 --- a/sdks/python/apache_beam/examples/inference/README.md +++ b/sdks/python/apache_beam/examples/inference/README.md @@ -968,4 +968,70 @@ and produce the following result in your output file location: An emperor penguin is an adorable creature that lives in Antarctica. ``` +--- +## Table row inference + +[`table_row_inference.py`](./table_row_inference.py) contains an implementation for a RunInference pipeline that processes structured table rows from a file or Pub/Sub, runs ML inference while preserving the table schema, and writes results to BigQuery. It supports both batch (file input) and streaming (Pub/Sub) modes. + +### Prerequisites for table row inference + +Install dependencies (or use `apache_beam/ml/inference/table_row_inference_requirements.txt` from the `sdks/python` directory): + +```sh +pip install apache-beam[gcp] scikit-learn google-cloud-pubsub +``` + +For streaming mode you need a Pub/Sub topic and subscription, a BigQuery dataset, and a GCS bucket for model and temp files. + +### Model and data for table row inference + +1. Create a scikit-learn model and sample data using the provided utilities: + +```sh +python -m apache_beam.examples.inference.table_row_inference_utils --action=create_model --output_path=model.pkl --num_features=3 +python -m apache_beam.examples.inference.table_row_inference_utils --action=generate_data --output_path=input_data.jsonl --num_rows=1000 --num_features=3 +``` + +2. 
Input data should be JSONL with an `id` field and feature columns, for example: + +```json +{"id": "row_1", "feature1": 1.5, "feature2": 2.3, "feature3": 3.7} +``` + +### Running `table_row_inference.py` (batch) + +To run the table row inference pipeline in batch mode locally: + +```sh +python -m apache_beam.examples.inference.table_row_inference \ + --mode=batch \ + --input_file=input_data.jsonl \ + --output_table=PROJECT:DATASET.predictions \ + --model_path=model.pkl \ + --feature_columns=feature1,feature2,feature3 \ + --runner=DirectRunner +``` + +### Running `table_row_inference.py` (streaming) + +For streaming mode, use a Pub/Sub subscription and DataflowRunner. Set up a topic and subscription first, then run: + +```sh +python -m apache_beam.examples.inference.table_row_inference \ + --mode=streaming \ + --input_subscription=projects/PROJECT/subscriptions/SUBSCRIPTION \ + --output_table=PROJECT:DATASET.predictions \ + --model_path=gs://BUCKET/model.pkl \ + --feature_columns=feature1,feature2,feature3 \ + --runner=DataflowRunner \ + --project=PROJECT \ + --region=us-central1 \ + --temp_location=gs://BUCKET/temp \ + --staging_location=gs://BUCKET/staging +``` + +See the script for full pipeline options (window size, trigger interval, worker settings, etc.). + +Output is written to the BigQuery table with columns such as `row_key`, `prediction`, and the original input feature columns. + --- \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference.py b/sdks/python/apache_beam/examples/inference/table_row_inference.py new file mode 100644 index 000000000000..3608a0ae8a39 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/table_row_inference.py @@ -0,0 +1,337 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A pipeline that uses RunInference to perform inference on continuous table rows. + +This pipeline demonstrates ML Pipelines #18: handling continuous new table rows +with RunInference using table input models. It reads structured data (table rows) +from a streaming source, performs inference while preserving the table schema, +and writes results to a table output. 
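For example, an input record like the following (values illustrative):

    {"id": "row_1", "feature1": 1.5, "feature2": 2.3, "feature3": 3.7}

is keyed by its 'id' field, converted to a beam.Row of the configured
feature columns, scored by the model, and written out as a table row of
the form {'row_key': 'row_1', 'prediction': ..., 'model_id': ...,
'input_feature1': 1.5, ...}, where the prediction value depends on the
loaded model.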
+ +The pipeline supports both streaming and batch modes: +- Streaming: Reads from Pub/Sub, applies windowing, writes via streaming inserts +- Batch: Reads from file, processes all data, writes via file loads + +Example usage for streaming: + python table_row_inference.py \ + --mode=streaming \ + --input_subscription=projects/PROJECT/subscriptions/SUBSCRIPTION \ + --output_table=PROJECT:DATASET.TABLE \ + --model_path=gs://BUCKET/model.pkl \ + --feature_columns=feature1,feature2,feature3 \ + --runner=DataflowRunner \ + --project=PROJECT \ + --region=REGION \ + --temp_location=gs://BUCKET/temp + +Example usage for batch: + python table_row_inference.py \ + --mode=batch \ + --input_file=gs://BUCKET/input.jsonl \ + --output_table=PROJECT:DATASET.TABLE \ + --model_path=gs://BUCKET/model.pkl \ + --feature_columns=feature1,feature2,feature3 +""" + +import argparse +import json +import logging +from collections.abc import Iterable +from typing import Any +from typing import Optional + +import apache_beam as beam +import numpy as np +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.sklearn_inference import ModelFileType +from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult + + +class TableRowModelHandler(SklearnModelHandlerNumpy): + """ModelHandler that processes table rows (beam.Row objects) for inference. + + This handler extends SklearnModelHandlerNumpy to work with structured + table data represented as beam.Row objects. It extracts specified feature + columns from the row and converts them to numpy arrays for model input. + + Attributes: + feature_columns: List of column names to extract as features from input rows + """ + def __init__( + self, + model_uri: str, + feature_columns: list[str], + model_file_type: ModelFileType = ModelFileType.PICKLE): + """Initialize the TableRowModelHandler. + + Args: + model_uri: Path to the saved model file (local or GCS) + feature_columns: List of column names to use as model features + model_file_type: Type of model file (PICKLE or JOBLIB) + """ + super().__init__(model_uri=model_uri, model_file_type=model_file_type) + self.feature_columns = feature_columns + + def run_inference( + self, + batch: list[beam.Row], + model: Any, + inference_args: Optional[dict[str, Any]] = None) -> Iterable[PredictionResult]: + """Run inference on a batch of beam.Row objects. + + Args: + batch: List of beam.Row objects containing input features + model: Loaded sklearn model + inference_args: Optional additional arguments for inference + + Yields: + PredictionResult containing the original row and prediction + """ + features_array = [] + for row in batch: + row_dict = row._asdict() + features = [row_dict[col] for col in self.feature_columns] + features_array.append(features) + + features_array = np.array(features_array, dtype=np.float32) + predictions = model.predict(features_array) + + for row, prediction in zip(batch, predictions): + yield PredictionResult( + example=row, inference=float(prediction), model_id=self._model_uri) + + +class FormatTableOutput(beam.DoFn): + """DoFn that formats inference results into table output schema. 
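For a keyed result such as ('row_1', PredictionResult(example=beam.Row(
feature1=1.5, ...), inference=7.5, model_id='m')), the emitted dictionary
looks like {'row_key': 'row_1', 'prediction': 7.5, 'model_id': 'm',
'input_feature1': 1.5, ...} (values illustrative).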
+ + Takes PredictionResult objects from KeyedModelHandler and formats them + into dictionaries suitable for writing to BigQuery or other table outputs. + """ + def __init__(self, feature_columns: list[str]): + self.feature_columns = feature_columns + + def process( + self, element: tuple[str, PredictionResult]) -> Iterable[dict[str, Any]]: + """Process a keyed inference result into table output format. + + Args: + element: Tuple of (row_key, PredictionResult) + + Yields: + Dictionary with all input fields plus prediction and metadata + """ + key, prediction = element + row = prediction.example + row_dict = row._asdict() + output = {'row_key': key, 'prediction': prediction.inference} + + if prediction.model_id: + output['model_id'] = prediction.model_id + + for field_name in self.feature_columns: + output[f'input_{field_name}'] = row_dict[field_name] + + yield output + + +def parse_json_to_table_row( + message: bytes, + schema_fields: Optional[list[str]] = None) -> tuple[str, beam.Row]: + """Parse JSON message to (key, beam.Row) format for KeyedModelHandler. + + Args: + message: JSON-encoded bytes + schema_fields: Optional list of expected field names + + Returns: + Tuple of (unique_key, beam.Row with parsed data) + """ + data = json.loads(message.decode('utf-8')) + + row_key = data.get('id', str(hash(message))) + + row_fields = {} + for key, value in data.items(): + if key != 'id' and (schema_fields is None or key in schema_fields): + if isinstance(value, (int, float)): + row_fields[key] = float(value) + else: + row_fields[key] = value + + table_row = beam.Row(**row_fields) + return row_key, table_row + + +def build_output_schema(feature_columns: list[str]) -> str: + """Build BigQuery schema string for output table. + + Args: + feature_columns: List of feature column names + + Returns: + BigQuery schema string + """ + schema_parts = ['row_key:STRING', 'prediction:FLOAT', 'model_id:STRING'] + + for col in feature_columns: + schema_parts.append(f'input_{col}:FLOAT') + + return ','.join(schema_parts) + + +def parse_known_args(argv): + """Parse command-line arguments for the pipeline.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + '--mode', + default='batch', + choices=['streaming', 'batch'], + help='Pipeline mode: streaming or batch') + parser.add_argument( + '--input_subscription', + help='Pub/Sub subscription for streaming mode ' + '(format: projects/PROJECT/subscriptions/SUBSCRIPTION)') + parser.add_argument( + '--input_file', + help='Input file path for batch mode (e.g., gs://bucket/input.jsonl)') + parser.add_argument( + '--output_table', + help='BigQuery output table (format: PROJECT:DATASET.TABLE)') + parser.add_argument( + '--model_path', help='Path to saved model file') + parser.add_argument( + '--feature_columns', + help='Comma-separated list of feature column names') + parser.add_argument( + '--window_size_sec', + type=int, + default=60, + help='Window size in seconds for streaming mode (default: 60)') + parser.add_argument( + '--trigger_interval_sec', + type=int, + default=30, + help='Trigger interval in seconds for streaming mode (default: 30)') + parser.add_argument( + '--input_expand_factor', + type=int, + default=1, + help='In batch mode: repeat each input line this many times to scale up ' + 'volume (e.g. 100k lines × 100 = 10M rows). Default 1 = no expansion.') + return parser.parse_known_args(argv) + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + """Main pipeline execution function. 
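Depending on --mode, this builds either a streaming pipeline (Pub/Sub
source, fixed windows with a processing-time trigger, STREAMING_INSERTS
writes to BigQuery) or a batch pipeline (file source with optional input
expansion, FILE_LOADS writes).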
+ + Args: + argv: Command-line arguments + save_main_session: Whether to save main session for workers + test_pipeline: Optional test pipeline (for testing) + + Returns: + PipelineResult from pipeline execution + """ + known_args, pipeline_args = parse_known_args(argv) + + if known_args.mode == 'streaming' and not known_args.input_subscription: + raise ValueError('input_subscription is required for streaming mode') + if known_args.mode == 'batch' and not known_args.input_file: + raise ValueError('input_file is required for batch mode') + + feature_columns = [ + col.strip() for col in known_args.feature_columns.split(',') + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(StandardOptions).streaming = ( + known_args.mode == 'streaming') + + model_handler = TableRowModelHandler( + model_uri=known_args.model_path, feature_columns=feature_columns) + + output_schema = build_output_schema(feature_columns) + + pipeline = test_pipeline or beam.Pipeline(options=pipeline_options) + + if known_args.mode == 'streaming': + input_data = ( + pipeline + | 'ReadFromPubSub' >> + beam.io.ReadFromPubSub(subscription=known_args.input_subscription) + | 'ParseToTableRows' >> beam.Map( + lambda msg: parse_json_to_table_row(msg, feature_columns)) + | 'WindowedData' >> beam.WindowInto( + beam.window.FixedWindows(known_args.window_size_sec), + trigger=beam.trigger.AfterProcessingTime( + known_args.trigger_interval_sec), + accumulation_mode=beam.trigger.AccumulationMode.DISCARDING, + allowed_lateness=0)) + write_method = beam.io.WriteToBigQuery.Method.STREAMING_INSERTS + else: + read_lines = ( + pipeline + | 'ReadFromFile' >> beam.io.ReadFromText(known_args.input_file)) + expand_factor = getattr( + known_args, 'input_expand_factor', 1) or 1 + if expand_factor > 1: + read_lines = ( + read_lines + | 'ExpandInput' >> beam.FlatMap( + lambda line: [line] * expand_factor)) + input_data = ( + read_lines + | 'ParseToTableRows' >> beam.Map( + lambda line: parse_json_to_table_row( + line.encode('utf-8'), feature_columns))) + write_method = beam.io.WriteToBigQuery.Method.FILE_LOADS + + write_disposition = ( + beam.io.BigQueryDisposition.WRITE_APPEND + if known_args.mode == 'streaming' + else beam.io.BigQueryDisposition.WRITE_TRUNCATE) + _ = ( + input_data + | 'RunInference' >> RunInference(KeyedModelHandler(model_handler)) + | 'FormatOutput' >> beam.ParDo(FormatTableOutput(feature_columns)) + | 'WriteToBigQuery' >> beam.io.WriteToBigQuery( + known_args.output_table, + schema=output_schema, + write_disposition=write_disposition, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + method=write_method)) + + result = pipeline.run() + + if known_args.mode == 'batch' and not test_pipeline: + result.wait_until_finish() + + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py new file mode 100644 index 000000000000..57ded48db5bf --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch inference pipeline for table rows using RunInference. + +This is a simplified batch-only implementation of ML Pipelines #18. +It reads table data from files, runs ML inference, and writes results. + +Key Features: +- BATCH PROCESSING ONLY (no streaming complexity) +- Reads from files (JSONL, CSV, or custom) +- Preserves table schema +- Writes to BigQuery or files +- Simple and easy to understand + +Example usage: + + # Basic usage with local files + python table_row_inference_batch.py \ + --input_file=data.jsonl \ + --output_table=project:dataset.table \ + --model_path=model.pkl \ + --feature_columns=feature1,feature2,feature3 + + # With Dataflow + python table_row_inference_batch.py \ + --input_file=gs://bucket/data.jsonl \ + --output_table=project:dataset.table \ + --model_path=gs://bucket/model.pkl \ + --feature_columns=feature1,feature2,feature3 \ + --runner=DataflowRunner \ + --project=PROJECT \ + --region=us-central1 \ + --temp_location=gs://bucket/temp + + # Output to file instead of BigQuery + python table_row_inference_batch.py \ + --input_file=data.jsonl \ + --output_file=predictions.jsonl \ + --model_path=model.pkl \ + --feature_columns=feature1,feature2,feature3 +""" + +import argparse +import json +import logging +from collections.abc import Iterable +from typing import Any +from typing import Optional + +import apache_beam as beam +import numpy as np +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions + + +class BatchTableRowModelHandler(SklearnModelHandlerNumpy): + """ModelHandler for batch processing of table rows. + + This handler is optimized for batch inference on structured table data. + It extracts specified feature columns and runs inference in batches. + """ + def __init__(self, model_uri: str, feature_columns: list[str]): + """Initialize the batch model handler. + + Args: + model_uri: Path to the saved model (local or GCS) + feature_columns: List of column names to use as features + """ + super().__init__(model_uri=model_uri) + self.feature_columns = feature_columns + logging.info( + f'Initialized BatchTableRowModelHandler with features: {feature_columns}' + ) + + def run_inference( + self, + batch: list[beam.Row], + model: Any, + inference_args: Optional[dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Run batch inference on table rows. 
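Feature columns missing from a row default to 0.0; the batch is then
stacked into a float32 numpy array and passed to model.predict.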
+ + Args: + batch: List of beam.Row objects with table data + model: Loaded scikit-learn model + inference_args: Optional inference arguments (unused) + + Yields: + PredictionResult for each input row + """ + features_list = [] + for row in batch: + row_dict = row._asdict() + features = [row_dict.get(col, 0.0) for col in self.feature_columns] + features_list.append(features) + + features_array = np.array(features_list, dtype=np.float32) + + predictions = model.predict(features_array) + + for row, prediction in zip(batch, predictions): + yield PredictionResult( + example=row, inference=float(prediction), model_id=self._model_uri) + + +class FormatBatchOutput(beam.DoFn): + """Format inference results for batch output.""" + def __init__(self, feature_columns: list[str], include_metadata: bool = True): + """Initialize formatter. + + Args: + feature_columns: List of feature column names to include + include_metadata: Whether to include model_id in output + """ + self.feature_columns = feature_columns + self.include_metadata = include_metadata + + def process( + self, element: tuple[str, PredictionResult]) -> Iterable[dict[str, Any]]: + """Format a keyed inference result. + + Args: + element: Tuple of (row_key, PredictionResult) + + Yields: + Dictionary with formatted output + """ + key, prediction = element + row = prediction.example + row_dict = row._asdict() + + output = {'id': key, 'prediction': prediction.inference} + + if self.include_metadata and prediction.model_id: + output['model_id'] = prediction.model_id + + for field_name in self.feature_columns: + output[field_name] = row_dict[field_name] + + yield output + + +def parse_jsonl_line(line: str, schema_fields: list[str]) -> tuple[str, beam.Row]: + """Parse a JSONL line to (key, beam.Row) format. + + Args: + line: JSON string + schema_fields: Expected field names + + Returns: + Tuple of (row_id, beam.Row) + """ + data = json.loads(line) + + row_id = data.get('id', str(hash(line))) + + row_fields = {} + for field in schema_fields: + if field in data: + value = data[field] + row_fields[field] = float(value) if isinstance(value, (int, float)) else value + + return row_id, beam.Row(**row_fields) + + +def build_bigquery_schema(feature_columns: list[str]) -> str: + """Build BigQuery schema for output table. + + Args: + feature_columns: List of feature column names + + Returns: + BigQuery schema string + """ + schema_fields = ['id:STRING', 'prediction:FLOAT', 'model_id:STRING'] + + for col in feature_columns: + schema_fields.append(f'{col}:FLOAT') + + return ','.join(schema_fields) + + +def run_batch_inference( + input_file: str, + model_path: str, + feature_columns: list[str], + output_table: Optional[str] = None, + output_file: Optional[str] = None, + pipeline_options: Optional[PipelineOptions] = None) -> beam.Pipeline: + """Run batch inference pipeline. 
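The pipeline reads JSONL lines, keys each row by its 'id' field, runs
RunInference with a KeyedModelHandler, and writes the formatted results
to BigQuery (FILE_LOADS, WRITE_TRUNCATE) and/or to a JSONL output file.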
+ + Args: + input_file: Path to input file (JSONL format) + model_path: Path to saved model + feature_columns: List of feature column names + output_table: Optional BigQuery table (PROJECT:DATASET.TABLE) + output_file: Optional output file path + pipeline_options: Beam pipeline options + + Returns: + Executed pipeline + """ + if not output_table and not output_file: + raise ValueError('Must specify either output_table or output_file') + + pipeline_options = pipeline_options or PipelineOptions() + + model_handler = BatchTableRowModelHandler( + model_uri=model_path, feature_columns=feature_columns) + + logging.info(f'Starting batch inference pipeline') + logging.info(f' Input: {input_file}') + logging.info(f' Model: {model_path}') + logging.info(f' Features: {feature_columns}') + logging.info( + f' Output: {output_table if output_table else output_file}') + + with beam.Pipeline(options=pipeline_options) as pipeline: + + input_data = ( + pipeline + | 'ReadInputFile' >> beam.io.ReadFromText(input_file) + | 'ParseToRows' >> beam.Map( + lambda line: parse_jsonl_line(line, feature_columns))) + + predictions = ( + input_data + | 'RunInference' >> RunInference(KeyedModelHandler(model_handler)) + | 'FormatOutput' >> beam.ParDo(FormatBatchOutput(feature_columns))) + + if output_table: + schema = build_bigquery_schema(feature_columns) + _ = ( + predictions + | 'WriteToBigQuery' >> beam.io.WriteToBigQuery( + output_table, + schema=schema, + write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + method=beam.io.WriteToBigQuery.Method.FILE_LOADS)) + + if output_file: + _ = ( + predictions + | 'FormatJSON' >> beam.Map(json.dumps) + | 'WriteToFile' >> beam.io.WriteToText( + output_file, file_name_suffix='.jsonl', shard_name_template='')) + + logging.info('Batch inference pipeline completed successfully') + return pipeline + + +def main(argv=None): + """Main entry point for the batch inference pipeline.""" + parser = argparse.ArgumentParser( + description='Batch inference on table rows using RunInference', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__) + + parser.add_argument( + '--input_file', + required=True, + help='Input file path (JSONL format). Can be local or GCS path.') + + parser.add_argument( + '--model_path', + required=True, + help='Path to saved model file. Can be local or GCS path.') + + parser.add_argument( + '--feature_columns', + required=True, + help='Comma-separated list of feature column names to extract from input rows.' + ) + + parser.add_argument( + '--output_table', + help='BigQuery output table in format PROJECT:DATASET.TABLE') + + parser.add_argument( + '--output_file', + help='Output file path (JSONL format). 
Alternative to output_table.') + + known_args, pipeline_args = parser.parse_known_args(argv) + + if not known_args.output_table and not known_args.output_file: + parser.error('Must specify either --output_table or --output_file') + + feature_columns = [col.strip() for col in known_args.feature_columns.split(',')] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = True + + run_batch_inference( + input_file=known_args.input_file, + model_path=known_args.model_path, + feature_columns=feature_columns, + output_table=known_args.output_table, + output_file=known_args.output_file, + pipeline_options=pipeline_options) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + main() diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py new file mode 100644 index 000000000000..205ee34deabd --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Unit tests for table row inference pipeline.""" + +import json +import os +import pickle +import tempfile +import unittest + +import apache_beam as beam +import numpy as np +from apache_beam.examples.inference.table_row_inference import FormatTableOutput +from apache_beam.examples.inference.table_row_inference import TableRowModelHandler +from apache_beam.examples.inference.table_row_inference import build_output_schema +from apache_beam.examples.inference.table_row_inference import parse_json_to_table_row +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +try: + from sklearn.linear_model import LinearRegression + SKLEARN_AVAILABLE = True +except ImportError: + SKLEARN_AVAILABLE = False + + +class SimpleLinearModel: + """Simple model for testing without sklearn dependency.""" + def predict(self, X): + return np.sum(X, axis=1) + + +@unittest.skipIf(not SKLEARN_AVAILABLE, 'sklearn is not available') +class TableRowInferenceTest(unittest.TestCase): + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.model_path = os.path.join(self.tmp_dir, 'test_model.pkl') + + model = LinearRegression() + X_train = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + y_train = np.array([6, 15, 24]) + model.fit(X_train, y_train) + + with open(self.model_path, 'wb') as f: + pickle.dump(model, f) + + def test_parse_json_to_table_row(self): + json_data = json.dumps({ + 'id': 'test_1', + 'feature1': 1.0, + 'feature2': 2.0, + 'feature3': 3.0 + }).encode('utf-8') + + key, row = parse_json_to_table_row( + json_data, schema_fields=['feature1', 'feature2', 'feature3']) + + self.assertEqual(key, 'test_1') + self.assertEqual(row.feature1, 1.0) + self.assertEqual(row.feature2, 2.0) + self.assertEqual(row.feature3, 3.0) + + def test_build_output_schema(self): + feature_cols = ['feature1', 'feature2', 'feature3'] + schema = build_output_schema(feature_cols) + + expected_fields = [ + 'row_key', 'prediction', 'model_id', 'input_feature1', + 'input_feature2', 'input_feature3' + ] + + for field in expected_fields: + self.assertIn(field, schema) + + def test_table_row_model_handler(self): + model_handler = TableRowModelHandler( + model_uri=self.model_path, feature_columns=['f1', 'f2', 'f3']) + + model = model_handler.load_model() + + test_rows = [ + beam.Row(f1=1.0, f2=2.0, f3=3.0), + beam.Row(f1=4.0, f2=5.0, f3=6.0), + ] + + results = list(model_handler.run_inference(test_rows, model)) + + self.assertEqual(len(results), 2) + self.assertIsInstance(results[0], PredictionResult) + self.assertEqual(results[0].example, test_rows[0]) + self.assertIsNotNone(results[0].inference) + + def test_format_table_output(self): + row = beam.Row(feature1=1.0, feature2=2.0, feature3=3.0) + prediction_result = PredictionResult( + example=row, inference=6.0, model_id='test_model') + + keyed_result = ('test_key', prediction_result) + + feature_columns = ['feature1', 'feature2', 'feature3'] + formatter = FormatTableOutput(feature_columns=feature_columns) + outputs = list(formatter.process(keyed_result)) + + self.assertEqual(len(outputs), 1) + output = outputs[0] + + self.assertEqual(output['row_key'], 'test_key') + self.assertEqual(output['prediction'], 6.0) + self.assertEqual(output['model_id'], 'test_model') + self.assertEqual(output['input_feature1'], 
1.0) + self.assertEqual(output['input_feature2'], 2.0) + self.assertEqual(output['input_feature3'], 3.0) + + def test_pipeline_integration(self): + test_data = [ + json.dumps({ + 'id': 'row_1', 'feature1': 1.0, 'feature2': 2.0, 'feature3': 3.0 + }), + json.dumps({ + 'id': 'row_2', 'feature1': 4.0, 'feature2': 5.0, 'feature3': 6.0 + }), + ] + + feature_columns = ['feature1', 'feature2', 'feature3'] + model_handler = TableRowModelHandler( + model_uri=self.model_path, feature_columns=feature_columns) + + with TestPipeline() as p: + input_data = ( + p + | beam.Create(test_data) + | beam.Map( + lambda line: parse_json_to_table_row( + line.encode('utf-8'), feature_columns))) + + predictions = ( + input_data + | RunInference(KeyedModelHandler(model_handler)) + | beam.ParDo(FormatTableOutput(feature_columns=feature_columns))) + + def check_outputs(outputs): + self.assertEqual(len(outputs), 2) + + for output in outputs: + self.assertIn('row_key', output) + self.assertIn('prediction', output) + self.assertIn('input_feature1', output) + self.assertIn('input_feature2', output) + self.assertIn('input_feature3', output) + + assert_that(predictions, check_outputs) + + +class TableRowInferenceNoSklearnTest(unittest.TestCase): + """Tests that don't require sklearn.""" + def test_parse_json_without_schema(self): + json_data = json.dumps({'id': 'test', 'value': 123}).encode('utf-8') + + key, row = parse_json_to_table_row(json_data) + + self.assertEqual(key, 'test') + self.assertTrue(hasattr(row, 'value')) + + +if __name__ == '__main__': + unittest.main() + + + + + diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py b/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py new file mode 100644 index 000000000000..0c5f9814ca2f --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py @@ -0,0 +1,297 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Utility functions for table row inference pipeline. + +This module provides helper functions for testing and deploying the +table row inference pipeline, including data generation, model creation, +and Pub/Sub resource management. +""" + +import json +import logging +import pickle +import time +from typing import Optional + +try: + from google.cloud import pubsub_v1 + PUBSUB_AVAILABLE = True +except ImportError: + PUBSUB_AVAILABLE = False + logging.warning('google-cloud-pubsub not available') + +try: + import numpy as np + from sklearn.linear_model import LinearRegression + SKLEARN_AVAILABLE = True +except ImportError: + SKLEARN_AVAILABLE = False + logging.warning('sklearn not available') + + +def create_sample_model(output_path: str, num_features: int = 3): + """Create and save a simple linear regression model for testing. 
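The model is fit on random data whose target is the sum of the features
plus small noise, so a prediction for the input [1.0, 2.0, 3.0] should be
close to 6.0 (only approximately, since the training data is random).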
+ + Args: + output_path: Path where to save the model (local or GCS) + num_features: Number of input features + """ + if not SKLEARN_AVAILABLE: + raise ImportError('sklearn is required to create sample models') + + model = LinearRegression() + + X_train = np.random.randn(100, num_features) + y_train = np.sum(X_train, axis=1) + np.random.randn(100) * 0.1 + + model.fit(X_train, y_train) + + with open(output_path, 'wb') as f: + pickle.dump(model, f) + + logging.info(f'Sample model saved to {output_path}') + + +def generate_sample_data( + num_rows: int = 100, num_features: int = 3) -> list[dict]: + """Generate sample table row data for testing. + + Args: + num_rows: Number of rows to generate + num_features: Number of features per row + + Returns: + List of dictionaries representing table rows + """ + if not SKLEARN_AVAILABLE: + raise ImportError('numpy is required to generate sample data') + + data = [] + for i in range(num_rows): + row = {'id': f'row_{i}', 'timestamp': time.time()} + + for j in range(num_features): + row[f'feature{j+1}'] = float(np.random.randn()) + + data.append(row) + + return data + + +def write_data_to_file(data: list[dict], output_path: str): + """Write sample data to JSONL file. + + Args: + data: List of data dictionaries + output_path: Output file path + """ + with open(output_path, 'w') as f: + for row in data: + f.write(json.dumps(row) + '\n') + + logging.info(f'Wrote {len(data)} rows to {output_path}') + + +def publish_to_pubsub( + project: str, + topic_name: str, + data: list[dict], + rate_limit: Optional[float] = None): + """Publish sample data to Pub/Sub topic. + + Args: + project: GCP project ID + topic_name: Pub/Sub topic name + data: List of data dictionaries to publish + rate_limit: Optional rate limit (rows per second) + """ + if not PUBSUB_AVAILABLE: + raise ImportError('google-cloud-pubsub is required for Pub/Sub operations') + + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(project, topic_name) + + delay = 1.0 / rate_limit if rate_limit else 0 + + for i, row in enumerate(data): + message = json.dumps(row).encode('utf-8') + future = publisher.publish(topic_path, message) + + if (i + 1) % 100 == 0: + logging.info(f'Published {i + 1} messages') + + if delay > 0: + time.sleep(delay) + + logging.info(f'Published {len(data)} messages to {topic_path}') + + +def ensure_pubsub_topic(project: str, topic_name: str) -> str: + """Create Pub/Sub topic if it doesn't exist. + + Args: + project: GCP project ID + topic_name: Pub/Sub topic name + + Returns: + Full topic path + """ + if not PUBSUB_AVAILABLE: + raise ImportError('google-cloud-pubsub is required for Pub/Sub operations') + + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(project, topic_name) + + try: + publisher.get_topic(request={'topic': topic_path}) + logging.info(f'Topic {topic_name} already exists') + except Exception: + publisher.create_topic(name=topic_path) + logging.info(f'Created topic {topic_name}') + + return topic_path + + +def ensure_pubsub_subscription( + project: str, topic_name: str, subscription_name: str) -> str: + """Create Pub/Sub subscription if it doesn't exist. 
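Note that the existence check and creation are not atomic; a concurrent
caller may create the subscription between the get_subscription and
create_subscription calls, in which case create_subscription may raise
an already-exists error.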
+ + Args: + project: GCP project ID + topic_name: Pub/Sub topic name + subscription_name: Subscription name + + Returns: + Full subscription path + """ + if not PUBSUB_AVAILABLE: + raise ImportError('google-cloud-pubsub is required for Pub/Sub operations') + + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_path = publisher.topic_path(project, topic_name) + subscription_path = subscriber.subscription_path(project, subscription_name) + + try: + subscriber.get_subscription(request={'subscription': subscription_path}) + logging.info(f'Subscription {subscription_name} already exists') + except Exception: + subscriber.create_subscription( + name=subscription_path, topic=topic_path) + logging.info(f'Created subscription {subscription_name}') + + return subscription_path + + +def cleanup_pubsub_resources( + project: str, topic_name: str, subscription_name: Optional[str] = None): + """Delete Pub/Sub topic and optionally subscription. + + Args: + project: GCP project ID + topic_name: Pub/Sub topic name + subscription_name: Optional subscription name to delete + """ + if not PUBSUB_AVAILABLE: + raise ImportError('google-cloud-pubsub is required for Pub/Sub operations') + + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + if subscription_name: + subscription_path = subscriber.subscription_path(project, subscription_name) + try: + subscriber.delete_subscription(request={'subscription': subscription_path}) + logging.info(f'Deleted subscription {subscription_name}') + except Exception as e: + logging.warning(f'Failed to delete subscription: {e}') + + topic_path = publisher.topic_path(project, topic_name) + try: + publisher.delete_topic(request={'topic': topic_path}) + logging.info(f'Deleted topic {topic_name}') + except Exception as e: + logging.warning(f'Failed to delete topic: {e}') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + import argparse + parser = argparse.ArgumentParser( + description='Utility for table row inference pipeline') + parser.add_argument( + '--action', + required=True, + choices=[ + 'create_model', 'generate_data', 'publish_data', 'create_topic', + 'create_subscription', 'cleanup' + ], + help='Action to perform') + parser.add_argument('--project', help='GCP project ID') + parser.add_argument('--topic', help='Pub/Sub topic name') + parser.add_argument('--subscription', help='Pub/Sub subscription name') + parser.add_argument('--output_path', help='Output path for model or data') + parser.add_argument( + '--num_rows', type=int, default=100, help='Number of rows to generate') + parser.add_argument( + '--num_features', + type=int, + default=3, + help='Number of features per row') + parser.add_argument( + '--rate_limit', + type=float, + help='Rate limit for publishing (rows/sec)') + + args = parser.parse_args() + + if args.action == 'create_model': + if not args.output_path: + raise ValueError('--output_path required for create_model') + create_sample_model(args.output_path, args.num_features) + + elif args.action == 'generate_data': + if not args.output_path: + raise ValueError('--output_path required for generate_data') + data = generate_sample_data(args.num_rows, args.num_features) + write_data_to_file(data, args.output_path) + + elif args.action == 'publish_data': + if not args.project or not args.topic: + raise ValueError('--project and --topic required for publish_data') + data = generate_sample_data(args.num_rows, args.num_features) + publish_to_pubsub(args.project, 
args.topic, data, args.rate_limit) + + elif args.action == 'create_topic': + if not args.project or not args.topic: + raise ValueError('--project and --topic required for create_topic') + ensure_pubsub_topic(args.project, args.topic) + + elif args.action == 'create_subscription': + if not args.project or not args.topic or not args.subscription: + raise ValueError( + '--project, --topic, and --subscription required for create_subscription' + ) + ensure_pubsub_subscription(args.project, args.topic, args.subscription) + + elif args.action == 'cleanup': + if not args.project or not args.topic: + raise ValueError('--project and --topic required for cleanup') + cleanup_pubsub_resources(args.project, args.topic, args.subscription) diff --git a/sdks/python/apache_beam/ml/inference/table_row_inference_requirements.txt b/sdks/python/apache_beam/ml/inference/table_row_inference_requirements.txt new file mode 100644 index 000000000000..81770be263b5 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/table_row_inference_requirements.txt @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +scikit-learn>=1.0.0,<1.6.0 +numpy>=1.25.0,<2.5.0 +google-cloud-monitoring>=2.27.0 +protobuf>=4.25.1 +requests>=2.31.0 diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/README.md b/sdks/python/apache_beam/testing/benchmarks/inference/README.md index b76fdfa8ec5c..ccfdef2060b8 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/README.md +++ b/sdks/python/apache_beam/testing/benchmarks/inference/README.md @@ -118,6 +118,21 @@ Full pipeline implementation is available [here](https://github.com/apache/beam/ Full pipeline implementation is available [here](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/vllm_gemma_batch.py). +## Table Row Inference (Sklearn) + +The Table Row Inference benchmark runs batch and streaming pipelines using a Sklearn +table classifier. Required GCS artifacts (shared bucket): + +- **Staging/temp**: `gs://temp-storage-for-perf-tests/loadtests` +- **Batch input**: `gs://apache-beam-ml/testing/inputs/table_rows_100k_benchmark.jsonl` +- **Model**: `gs://apache-beam-ml/models/sklearn_table_classifier.pkl` + +Streaming uses a Pub/Sub subscription (e.g. `projects/apache-beam-testing/subscriptions/table_row_inference_benchmark`). +Pipeline options files: `beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Batch.txt` and +`beam_Inference_Python_Benchmarks_Dataflow_Table_Row_Inference_Stream.txt`. + +Full pipeline implementation is available [here](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/table_row_inference.py). + ## How to add a new ML benchmark pipeline 1. 
Create the pipeline implementation diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py new file mode 100644 index 000000000000..5ddf3cf40306 --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Benchmark test for table row inference pipeline. + +This benchmark measures the performance of RunInference with continuous +table row inputs, including throughput, latency, and cost metrics. +""" + +import logging + +from apache_beam.examples.inference import table_row_inference +from apache_beam.options.pipeline_options import DebugOptions +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.options.pipeline_options import WorkerOptions +from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark +from apache_beam.testing.load_tests.load_test import LoadTestOptions + + +class TableRowInferenceOptions( + LoadTestOptions, + StandardOptions, + GoogleCloudOptions, + WorkerOptions, + DebugOptions, + SetupOptions, +): + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument('--mode', default='batch') + parser.add_argument('--input_subscription') + parser.add_argument('--input_file') + parser.add_argument('--output_table') + parser.add_argument('--model_path') + parser.add_argument('--feature_columns') + parser.add_argument('--window_size_sec', type=int, default=60) + parser.add_argument('--trigger_interval_sec', type=int, default=30) + parser.add_argument('--input_expand_factor', type=int, default=1) + + +class TableRowInferenceBenchmarkTest(DataflowCostBenchmark): + """Benchmark for continuous table row inference with RunInference. 
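It wraps the example pipeline in table_row_inference.py, forwarding the
mode-specific options (subscription and windowing settings in streaming
mode; the input file in batch mode) from the load-test options to the
pipeline.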
+ + This benchmark measures: + - Mean Inference Batch Size: Average batch size for inference + - Mean Inference Batch Latency: Average time per batch inference + - Mean Load Model Latency: Time to load the model + - Throughput: Elements processed per second + - Cost: Estimated cost on Dataflow + """ + options_class = TableRowInferenceOptions + + def __init__(self): + self.metrics_namespace = 'BeamML_TableInference' + super().__init__( + metrics_namespace=self.metrics_namespace, + is_streaming=False, + pcollection='RunInference/BeamML_RunInference_Postprocess-0.out0') + self.is_streaming = ( + (self.pipeline.get_option('mode') or 'batch') == 'streaming') + if self.is_streaming: + self.subscription = self.pipeline.get_option('input_subscription') + + def test(self): + """Execute the table row inference pipeline for benchmarking.""" + extra_opts = {} + + mode = self.pipeline.get_option('mode') or 'batch' + extra_opts['mode'] = mode + + if mode == 'streaming': + extra_opts['input_subscription'] = self.pipeline.get_option( + 'input_subscription') + extra_opts['window_size_sec'] = int( + self.pipeline.get_option('window_size_sec') or 60) + extra_opts['trigger_interval_sec'] = int( + self.pipeline.get_option('trigger_interval_sec') or 30) + else: + extra_opts['input_file'] = self.pipeline.get_option('input_file') + + for opt in ['output_table', 'model_path', 'feature_columns']: + val = self.pipeline.get_option(opt) + if val: + extra_opts[opt] = val + + self.result = table_row_inference.run( + self.pipeline.get_full_options_as_args(**extra_opts), + test_pipeline=self.pipeline) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + TableRowInferenceBenchmarkTest().run() diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py index 87af2ef6a507..ba621fb6fdf6 100644 --- a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -45,35 +45,28 @@ class DataflowCostBenchmark(LoadTest): If using InfluxDB with Basic HTTP authentication enabled, provide the following environment options: `INFLUXDB_USER` and `INFLUXDB_USER_PASSWORD`. - If the hardware configuration for the job includes use of a GPU, please + If the hardware configuration for the job includes use of a GPU, please specify the version in use with the Accelerator enumeration. This is used to calculate the cost of the job later, as different accelerators have different billing rates per hour of use. """ WORKER_START_PATTERN = re.compile( - r'^All workers have finished the startup processes and ' - r'began to receive work requests.*$') - WORKER_STOP_PATTERN = re.compile(r'^Stopping worker pool.*$') + r'All workers have finished the startup processes and ' + r'began to receive work requests') + WORKER_STOP_PATTERN = re.compile(r'Stopping worker pool') def __init__( self, metrics_namespace: Optional[str] = None, is_streaming: bool = False, gpu: Optional[costs.Accelerator] = None, - pcollection: str = 'ProcessOutput.out0'): - """ - Initializes DataflowCostBenchmark. - - Args: - metrics_namespace (Optional[str]): Namespace for metrics. - is_streaming (bool): Whether the pipeline is streaming or batch. - gpu (Optional[costs.Accelerator]): Optional GPU type. - pcollection (str): PCollection name to monitor throughput. 
- """ + pcollection: str = 'ProcessOutput.out0', + subscription: Optional[str] = None): self.is_streaming = is_streaming self.gpu = gpu self.pcollection = pcollection + self.subscription = subscription super().__init__(metrics_namespace=metrics_namespace) self.dataflow_client = DataflowApplicationClient( self.pipeline.get_pipeline_options()) @@ -84,8 +77,8 @@ def run(self) -> None: self.test() if not hasattr(self, 'result'): self.result = self.pipeline.run() - state = self.result.wait_until_finish(duration=self.timeout_ms) - assert state != PipelineState.FAILED + state = self.result.wait_until_finish(duration=self.timeout_ms) + assert state != PipelineState.FAILED logging.info( 'Pipeline complete, sleeping for 4 minutes to allow resource ' @@ -98,6 +91,8 @@ def run(self) -> None: logging.info(self.extra_metrics) self._metrics_monitor.publish_metrics(self.result, self.extra_metrics) + except Exception as e: + raise finally: self.cleanup() @@ -148,26 +143,52 @@ def _process_metrics_list(self, def _get_worker_time_interval( self, job_id: str) -> tuple[Optional[str], Optional[str]]: """Extracts worker start and stop times from job messages.""" - messages, _ = self.dataflow_client.list_messages( - job_id=job_id, - start_time=None, - end_time=None, - minimum_importance='JOB_MESSAGE_DETAILED') - start_time, end_time = None, None - for message in messages: - text = message.messageText - if text: - if self.WORKER_START_PATTERN.match(text): - start_time = message.time - if self.WORKER_STOP_PATTERN.match(text): - end_time = message.time - + page_token = None + all_messages = [] + last_message_time = None + while True: + messages, page_token = self.dataflow_client.list_messages( + job_id=job_id, + start_time=None, + end_time=None, + page_token=page_token, + minimum_importance='JOB_MESSAGE_DEBUG') + for message in messages: + text = message.messageText + if getattr(message, 'time', None): + last_message_time = message.time + if text: + all_messages.append(text) + if self.WORKER_START_PATTERN.search(text): + start_time = message.time + logging.info('Matched WORKER_START_PATTERN: %r', text) + if self.WORKER_STOP_PATTERN.search(text): + end_time = message.time + logging.info('Matched WORKER_STOP_PATTERN: %r', text) + if not page_token or (start_time and end_time): + break + if start_time and not end_time and self.is_streaming and last_message_time: + end_time = last_message_time + logging.info( + 'Using last job message time as end_time for streaming job: %s', + end_time) + if not start_time or not end_time: + logging.warning( + 'Could not determine worker time interval. 
' + 'start_time=%s, end_time=%s, total messages=%d', + start_time, end_time, len(all_messages)) return start_time, end_time def _get_throughput_metrics( - self, project: str, job_id: str, start_time: str, - end_time: str) -> dict[str, float]: + self, + project: str, + job_id: str, + start_time: str, + end_time: str, + pcollection_name: Optional[str] = None) -> dict[str, float]: + """Query Cloud Monitoring for per-PCollection throughput.""" + name = pcollection_name if pcollection_name is not None else self.pcollection interval = monitoring_v3.TimeInterval( start_time=start_time, end_time=end_time) aggregation = monitoring_v3.Aggregation( @@ -178,16 +199,16 @@ def _get_throughput_metrics( "Bytes": monitoring_v3.ListTimeSeriesRequest( name=f"projects/{project}", filter=f'metric.type=' - f'"dataflow.googleapis.com/job/estimated_bytes_produced_count" ' + f'"dataflow.googleapis.com/job/estimated_byte_count" ' f'AND metric.labels.job_id=' - f'"{job_id}" AND metric.labels.pcollection="{self.pcollection}"', + f'"{job_id}" AND metric.labels.pcollection="{name}"', interval=interval, aggregation=aggregation), "Elements": monitoring_v3.ListTimeSeriesRequest( name=f"projects/{project}", filter=f'metric.type="dataflow.googleapis.com/job/element_count" ' f'AND metric.labels.job_id="{job_id}" ' - f'AND metric.labels.pcollection="{self.pcollection}"', + f'AND metric.labels.pcollection="{name}"', interval=interval, aggregation=aggregation) } @@ -204,6 +225,52 @@ def _get_throughput_metrics( return metrics + def _get_streaming_throughput_metrics( + self, + project: str, + start_time: str, + end_time: str) -> dict[str, float]: + if not self.subscription: + return {'AvgThroughputBytes': 0.0, 'AvgThroughputElements': 0.0} + + sub_parts = self.subscription.split('/') + subscription_id = sub_parts[-1] if sub_parts else self.subscription + + interval = monitoring_v3.TimeInterval( + start_time=start_time, end_time=end_time) + aggregation = monitoring_v3.Aggregation( + alignment_period=Duration(seconds=60), + per_series_aligner=monitoring_v3.Aggregation.Aligner.ALIGN_RATE) + + requests = { + "Bytes": monitoring_v3.ListTimeSeriesRequest( + name=f"projects/{project}", + filter=f'metric.type=' + f'"pubsub.googleapis.com/subscription/byte_cost" ' + f'AND resource.labels.subscription_id="{subscription_id}"', + interval=interval, + aggregation=aggregation), + "Elements": monitoring_v3.ListTimeSeriesRequest( + name=f"projects/{project}", + filter=f'metric.type=' + f'"pubsub.googleapis.com/subscription/sent_message_count" ' + f'AND resource.labels.subscription_id="{subscription_id}"', + interval=interval, + aggregation=aggregation), + } + + metrics = {} + for key, req in requests.items(): + time_series = list( + self.monitoring_client.list_time_series(request=req)) + values = [ + point.value.double_value for series in time_series + for point in series.points + ] + avg_rate = sum(values) / len(values) if values else 0.0 + metrics[f"AvgThroughput{key}"] = avg_rate + return metrics + def _get_job_runtime(self, start_time: str, end_time: str) -> float: """Calculates the job runtime duration in seconds.""" start_dt = datetime.fromisoformat(start_time[:-1]) @@ -220,9 +287,20 @@ def _get_additional_metrics(self, logging.warning('Could not find valid worker start/end times.') return {} - throughput_metrics = self._get_throughput_metrics( - project, job_id, start_time, end_time) + runtime_seconds = self._get_job_runtime(start_time, end_time) + if self.is_streaming: + throughput_metrics = self._get_streaming_throughput_metrics( + 
project, start_time, end_time) + else: + throughput_metrics = self._get_throughput_metrics( + project, job_id, start_time, end_time) + if (throughput_metrics.get('AvgThroughputBytes', 0) == 0 and + throughput_metrics.get('AvgThroughputElements', 0) == 0): + logging.warning( + 'No throughput data for PCollection "%s". Check Dataflow job %s ' + 'graph for actual PCollection names (Runner v2 may use different ' + 'naming).', self.pcollection, job_id) return { **throughput_metrics, - "JobRuntimeSeconds": self._get_job_runtime(start_time, end_time), + "JobRuntimeSeconds": runtime_seconds, } diff --git a/sdks/python/apache_beam/testing/load_tests/load_test.py b/sdks/python/apache_beam/testing/load_tests/load_test.py index 20dea3932b49..960d7ee2ed6e 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test.py @@ -92,8 +92,14 @@ class LoadTest(object): following environment options: `INFLUXDB_USER` and `INFLUXDB_USER_PASSWORD`. """ def __init__(self, metrics_namespace=None): - # Be sure to set blocking to false for timeout_ms to work properly - self.pipeline = TestPipeline(is_integration_test=True, blocking=False) + options_class = getattr(self.__class__, 'options_class', None) + if options_class is not None: + options_list = TestPipeline.get_options_list() + options = options_class(options_list) + self.pipeline = TestPipeline( + options=options, is_integration_test=True, blocking=False) + else: + self.pipeline = TestPipeline(is_integration_test=True, blocking=False) assert not self.pipeline.blocking options = self.pipeline.get_pipeline_options().view_as(LoadTestOptions) diff --git a/sdks/python/apache_beam/testing/test_pipeline.py b/sdks/python/apache_beam/testing/test_pipeline.py index 6a96e32bb929..712da8636234 100644 --- a/sdks/python/apache_beam/testing/test_pipeline.py +++ b/sdks/python/apache_beam/testing/test_pipeline.py @@ -21,6 +21,7 @@ import argparse import shlex +import sys from unittest import SkipTest from apache_beam.internal import pickler @@ -166,6 +167,19 @@ def _parse_test_option_args(self, argv): return shlex.split(test_pipeline_options) \ if test_pipeline_options else [] + @classmethod + def get_options_list(cls, argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '--test-pipeline-options', + type=str, + action='store', + help='Pipeline options for the test') + known, _ = parser.parse_known_args(argv if argv is not None else sys.argv) + opts = known.test_pipeline_options or getattr( + cls, 'pytest_test_pipeline_options', None) + return shlex.split(opts) if opts else [] + def get_full_options_as_args(self, **extra_opts): """Get full pipeline options as an argument list. 
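The `options_class` hook above works together with `TestPipeline.get_options_list`: the `--test-pipeline-options` string is parsed into a typed options object before the `TestPipeline` is built, so benchmark-specific flags such as `--mode` are recognized at parse time (this is what `TableRowInferenceOptions` in this patch relies on). A minimal sketch of the intended usage; `MyBenchmarkOptions` and the flag values are illustrative stand-ins, not part of the patch:

```python
# Sketch only: MyBenchmarkOptions is a hypothetical stand-in for a
# benchmark's typed options class (this patch declares
# TableRowInferenceOptions for the real benchmark).
from apache_beam.options.pipeline_options import PipelineOptions


class MyBenchmarkOptions(PipelineOptions):
  @classmethod
  def _add_argparse_args(cls, parser):
    parser.add_argument('--mode', default='batch')
    parser.add_argument('--input_file')


# With this patch, a LoadTest subclass opts in by declaring
# `options_class = MyBenchmarkOptions`; LoadTest.__init__ then does
# roughly the following with the parsed --test-pipeline-options flags:
flags = ['--mode=streaming', '--input_file=gs://bucket/rows.jsonl']
options = MyBenchmarkOptions(flags)  # PipelineOptions accepts a flag list
assert options.view_as(MyBenchmarkOptions).mode == 'streaming'
```

Because the typed options object is handed to `TestPipeline` up front, custom flags no longer trigger unknown-argument failures when the pipeline options are first constructed.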
diff --git a/website/www/site/content/en/performance/_index.md b/website/www/site/content/en/performance/_index.md
index 17bdc6f3de0a..1624c58efe2e 100644
--- a/website/www/site/content/en/performance/_index.md
+++ b/website/www/site/content/en/performance/_index.md
@@ -46,6 +46,7 @@ See the following pages for performance measures recorded when running various B
 ## Streaming
 
 - [PyTorch Sentiment Analysis Streaming DistilBERT base](/performance/pytorchbertsentimentstreaming)
+- [Table Row Inference Sklearn Streaming](/performance/tablerowinferencestreaming)
 
 ## Batch
 
@@ -57,3 +58,4 @@ See the following pages for performance measures recorded when running various B
 - [PyTorch Vision Classification Resnet 152 Tesla T4 GPU](/performance/pytorchresnet152tesla)
 - [TensorFlow MNIST Image Classification](/performance/tensorflowmnist)
 - [VLLM Gemma Batch Completion Tesla T4 GPU](/performance/vllmgemmabatchtesla)
+- [Table Row Inference Sklearn Batch](/performance/tablerowinference)
diff --git a/website/www/site/content/en/performance/tablerowinference/_index.md b/website/www/site/content/en/performance/tablerowinference/_index.md
new file mode 100644
index 000000000000..464c5cb263c6
--- /dev/null
+++ b/website/www/site/content/en/performance/tablerowinference/_index.md
@@ -0,0 +1,45 @@
+---
+title: "Table Row Inference Sklearn Batch Performance"
+---
+
+
+
+# Table Row Inference Sklearn Batch
+
+**Model**: Scikit-learn classifier on structured table data (beam.Row)
+**Accelerator**: CPU-based inference (fixed batch size)
+**Host**: 10 × n1-standard-4 (4 vCPUs, 15 GB RAM)
+
+This batch pipeline performs inference on continuous table rows using RunInference with a Scikit-learn model.
+It reads structured data (table rows) from GCS in JSONL format, extracts the specified feature columns, and runs batched inference while preserving the original table schema.
+The pipeline ensures exactly-once semantics within batch execution by deduplicating inputs and writing results to BigQuery using file-based loads, enabling reproducible and comparable performance measurements across runs.
+
+The following graphs show various metrics when running the Table Row Inference Sklearn Batch pipeline.
+See the [glossary](/performance/glossary) for definitions.
+
+Full pipeline implementation is available [here](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/table_row_inference.py).
+
+## What is the estimated cost to run the pipeline?
+
+{{< performance_looks io="tablerowinference" read_or_write="write" section="cost" >}}
+
+## How have various metrics changed when running the pipeline for different Beam SDK versions?
+
+{{< performance_looks io="tablerowinference" read_or_write="write" section="version" >}}
+
+## How have various metrics changed over time when running the pipeline?
+
+{{< performance_looks io="tablerowinference" read_or_write="write" section="date" >}}
+
+See also [Table Row Inference Sklearn Streaming](/performance/tablerowinferencestreaming) for the streaming variant of this pipeline.
+
diff --git a/website/www/site/content/en/performance/tablerowinferencestreaming/_index.md b/website/www/site/content/en/performance/tablerowinferencestreaming/_index.md
new file mode 100644
index 000000000000..1420e2ce6b9a
--- /dev/null
+++ b/website/www/site/content/en/performance/tablerowinferencestreaming/_index.md
@@ -0,0 +1,43 @@
+---
+title: "Table Row Inference Sklearn Streaming Performance"
+---
+
+
+
+# Table Row Inference Sklearn Streaming
+
+**Model**: Scikit-learn classifier on structured table data (beam.Row)
+**Accelerator**: CPU-based inference (fixed batch size)
+**Host**: 10 × n1-standard-4 (4 vCPUs, 15 GB RAM), autoscaling up to 20 workers (THROUGHPUT_BASED)
+
+This streaming pipeline performs inference on continuous table rows using RunInference with a Scikit-learn model.
+It reads messages from Pub/Sub, applies windowing, runs batched inference while preserving the table schema, and writes results to BigQuery via streaming inserts.
+
+The following graphs show various metrics when running the Table Row Inference Sklearn Streaming pipeline.
+See the [glossary](/performance/glossary) for definitions.
+
+Full pipeline implementation is available [here](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/table_row_inference.py).
+
+## What is the estimated cost to run the pipeline?
+
+{{< performance_looks io="tablerowinferencestreaming" read_or_write="write" section="cost" >}}
+
+## How have various metrics changed when running the pipeline for different Beam SDK versions?
+
+{{< performance_looks io="tablerowinferencestreaming" read_or_write="write" section="version" >}}
+
+## How have various metrics changed over time when running the pipeline?
+
+{{< performance_looks io="tablerowinferencestreaming" read_or_write="write" section="date" >}}
+
+See also [Table Row Inference Sklearn Batch](/performance/tablerowinference) for the batch variant of this pipeline.
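Both pages above describe the same `table_row_inference.py` pipeline in its two modes. A condensed sketch of the streaming shape the streaming page describes follows; the subscription, model path, and output table are placeholders rather than the benchmark's real resources, and the write method and trigger wiring are simplified assumptions, so treat the linked implementation as authoritative:

```python
# Condensed sketch of the streaming mode described above. Values marked
# "placeholder" are illustrative, not the benchmark's actual resources.
import apache_beam as beam
from apache_beam.examples.inference.table_row_inference import FormatTableOutput
from apache_beam.examples.inference.table_row_inference import TableRowModelHandler
from apache_beam.examples.inference.table_row_inference import parse_json_to_table_row
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions

features = ['feature1', 'feature2', 'feature3']
handler = TableRowModelHandler(
    model_uri='gs://my-bucket/models/table_classifier.pkl',  # placeholder
    feature_columns=features)

with beam.Pipeline(options=PipelineOptions(streaming=True)) as p:
  _ = (
      p
      | beam.io.ReadFromPubSub(
          subscription='projects/my-project/subscriptions/my-sub')
      | beam.Map(lambda msg: parse_json_to_table_row(msg, features))
      | beam.WindowInto(
          beam.window.FixedWindows(60),  # --window_size_sec
          trigger=beam.trigger.Repeatedly(
              beam.trigger.AfterProcessingTime(30)),  # --trigger_interval_sec
          accumulation_mode=beam.trigger.AccumulationMode.DISCARDING)
      | RunInference(KeyedModelHandler(handler))
      | beam.ParDo(FormatTableOutput(feature_columns=features))
      | beam.io.WriteToBigQuery(
          'my-project:my_dataset.predictions',  # placeholder; table must exist
          method=beam.io.WriteToBigQuery.Method.STREAMING_INSERTS,
          write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
          create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER))
```

The windowing keeps inference batches bounded under a continuous source, and `WRITE_APPEND` matches the streaming disposition used in this patch; the batch variant instead truncates its output table so repeated benchmark runs stay comparable.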
diff --git a/website/www/site/data/performance.yaml b/website/www/site/data/performance.yaml index 17a6612160c6..3fd9a948abf7 100644 --- a/website/www/site/data/performance.yaml +++ b/website/www/site/data/performance.yaml @@ -250,3 +250,35 @@ looks: title: AvgThroughputBytesPerSec by Version - id: dKyJy5ZKhkBdSTXRY3wZR6fXzptSs2qm title: AvgThroughputElementsPerSec by Version + tablerowinference: + write: + folder: 96 + cost: + - id: Yj3r3VpFDxQwNmPzSsq7wJKf628FxtTg + title: RunTime and EstimatedCost + date: + - id: 82WSkFRcNHm5gbzWdSJFQxDPyGMwFMRj + title: AvgThroughputBytesPerSec by Date + - id: cN8GSph5PfJgTFxhZHgW4KfkvzhnXcFG + title: AvgThroughputElementsPerSec by Date + version: + - id: cxkwmK48MWZWB5bd4DHmGMsMs4VTf4Jd + title: AvgThroughputBytesPerSec by Version + - id: kCbhxxJbh2fyyZZFVtPQdpV5jX36CySQ + title: AvgThroughputElementsPerSec by Version + tablerowinferencestreaming: + write: + folder: 106 + cost: + - id: WcSwy6KG8JQdYBVspDjbXqQgynrNWdxR + title: RunTime and EstimatedCost + date: + - id: c7h74dqjc4J7r4cGWFZwmzZK4WDBpRgt + title: AvgThroughputBytesPerSec by Date + - id: 7bznmBXqc2fwsQrGMfVWXZgNsyyq7Tdx + title: AvgThroughputElementsPerSec by Date + version: + - id: Y4464V8V4ngydnypmrrTVDmrYQsfb8mz + title: AvgThroughputBytesPerSec by Version + - id: P7wKZy6tQFWbbDfm4HzfCJnsQrVgfGsJ + title: AvgThroughputElementsPerSec by Version From 0ce6ad5112b9361993df613a49ae7c16cf244978 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Thu, 19 Feb 2026 20:38:10 +0200 Subject: [PATCH 2/5] Fix formatter, lint, and test for table row inference --- .../examples/inference/table_row_inference.py | 32 +++---- .../inference/table_row_inference_batch.py | 31 ++++--- .../inference/table_row_inference_test.py | 90 ++++++++++--------- .../inference/table_row_inference_utils.py | 57 ++++++------ .../table_row_inference_benchmark.py | 5 +- .../load_tests/dataflow_cost_benchmark.py | 16 ++-- 6 files changed, 116 insertions(+), 115 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference.py b/sdks/python/apache_beam/examples/inference/table_row_inference.py index 3608a0ae8a39..b75459595f2b 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference.py @@ -15,12 +15,12 @@ # limitations under the License. # -"""A pipeline that uses RunInference to perform inference on continuous table rows. +"""A pipeline that uses RunInference to perform inference on table rows. -This pipeline demonstrates ML Pipelines #18: handling continuous new table rows -with RunInference using table input models. It reads structured data (table rows) -from a streaming source, performs inference while preserving the table schema, -and writes results to a table output. +This pipeline demonstrates ML Pipelines #18: handling continuous new table +rows with RunInference using table input models. It reads structured data +(table rows) from a streaming source, performs inference while preserving +the table schema, and writes results to a table output. The pipeline supports both streaming and batch modes: - Streaming: Reads from Pub/Sub, applies windowing, writes via streaming inserts @@ -96,7 +96,8 @@ def run_inference( self, batch: list[beam.Row], model: Any, - inference_args: Optional[dict[str, Any]] = None) -> Iterable[PredictionResult]: + inference_args: Optional[dict[str, Any]] = None + ) -> Iterable[PredictionResult]: """Run inference on a batch of beam.Row objects. 
Args: @@ -218,11 +219,9 @@ def parse_known_args(argv): parser.add_argument( '--output_table', help='BigQuery output table (format: PROJECT:DATASET.TABLE)') + parser.add_argument('--model_path', help='Path to saved model file') parser.add_argument( - '--model_path', help='Path to saved model file') - parser.add_argument( - '--feature_columns', - help='Comma-separated list of feature column names') + '--feature_columns', help='Comma-separated list of feature column names') parser.add_argument( '--window_size_sec', type=int, @@ -282,8 +281,8 @@ def run( pipeline | 'ReadFromPubSub' >> beam.io.ReadFromPubSub(subscription=known_args.input_subscription) - | 'ParseToTableRows' >> beam.Map( - lambda msg: parse_json_to_table_row(msg, feature_columns)) + | 'ParseToTableRows' >> + beam.Map(lambda msg: parse_json_to_table_row(msg, feature_columns)) | 'WindowedData' >> beam.WindowInto( beam.window.FixedWindows(known_args.window_size_sec), trigger=beam.trigger.AfterProcessingTime( @@ -295,13 +294,11 @@ def run( read_lines = ( pipeline | 'ReadFromFile' >> beam.io.ReadFromText(known_args.input_file)) - expand_factor = getattr( - known_args, 'input_expand_factor', 1) or 1 + expand_factor = getattr(known_args, 'input_expand_factor', 1) or 1 if expand_factor > 1: read_lines = ( read_lines - | 'ExpandInput' >> beam.FlatMap( - lambda line: [line] * expand_factor)) + | 'ExpandInput' >> beam.FlatMap(lambda line: [line] * expand_factor)) input_data = ( read_lines | 'ParseToTableRows' >> beam.Map( @@ -310,8 +307,7 @@ def run( write_method = beam.io.WriteToBigQuery.Method.FILE_LOADS write_disposition = ( - beam.io.BigQueryDisposition.WRITE_APPEND - if known_args.mode == 'streaming' + beam.io.BigQueryDisposition.WRITE_APPEND if known_args.mode == 'streaming' else beam.io.BigQueryDisposition.WRITE_TRUNCATE) _ = ( input_data diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py index 57ded48db5bf..68d42f4a9594 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py @@ -89,8 +89,8 @@ def __init__(self, model_uri: str, feature_columns: list[str]): super().__init__(model_uri=model_uri) self.feature_columns = feature_columns logging.info( - f'Initialized BatchTableRowModelHandler with features: {feature_columns}' - ) + 'Initialized BatchTableRowModelHandler with features: %s', + feature_columns) def run_inference( self, @@ -160,7 +160,8 @@ def process( yield output -def parse_jsonl_line(line: str, schema_fields: list[str]) -> tuple[str, beam.Row]: +def parse_jsonl_line(line: str, + schema_fields: list[str]) -> tuple[str, beam.Row]: """Parse a JSONL line to (key, beam.Row) format. 
Args: @@ -178,7 +179,8 @@ def parse_jsonl_line(line: str, schema_fields: list[str]) -> tuple[str, beam.Row for field in schema_fields: if field in data: value = data[field] - row_fields[field] = float(value) if isinstance(value, (int, float)) else value + row_fields[field] = float(value) if isinstance( + value, (int, float)) else value return row_id, beam.Row(**row_fields) @@ -228,20 +230,20 @@ def run_batch_inference( model_handler = BatchTableRowModelHandler( model_uri=model_path, feature_columns=feature_columns) - logging.info(f'Starting batch inference pipeline') - logging.info(f' Input: {input_file}') - logging.info(f' Model: {model_path}') - logging.info(f' Features: {feature_columns}') + logging.info('Starting batch inference pipeline') + logging.info(' Input: %s', input_file) + logging.info(' Model: %s', model_path) + logging.info(' Features: %s', feature_columns) logging.info( - f' Output: {output_table if output_table else output_file}') + ' Output: %s', output_table if output_table else output_file) with beam.Pipeline(options=pipeline_options) as pipeline: input_data = ( pipeline | 'ReadInputFile' >> beam.io.ReadFromText(input_file) - | 'ParseToRows' >> beam.Map( - lambda line: parse_jsonl_line(line, feature_columns))) + | 'ParseToRows' >> + beam.Map(lambda line: parse_jsonl_line(line, feature_columns))) predictions = ( input_data @@ -290,7 +292,8 @@ def main(argv=None): parser.add_argument( '--feature_columns', required=True, - help='Comma-separated list of feature column names to extract from input rows.' + help= + 'Comma-separated list of feature column names to extract from input rows.' ) parser.add_argument( @@ -306,7 +309,9 @@ def main(argv=None): if not known_args.output_table and not known_args.output_file: parser.error('Must specify either --output_table or --output_file') - feature_columns = [col.strip() for col in known_args.feature_columns.split(',')] + feature_columns = [ + col.strip() for col in known_args.feature_columns.split(',') + ] pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = True diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py index 205ee34deabd..a164a0951d9c 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py @@ -28,13 +28,29 @@ from apache_beam.examples.inference.table_row_inference import FormatTableOutput from apache_beam.examples.inference.table_row_inference import TableRowModelHandler from apache_beam.examples.inference.table_row_inference import build_output_schema -from apache_beam.examples.inference.table_row_inference import parse_json_to_table_row +from apache_beam.examples.inference.table_row_inference import ( + parse_json_to_table_row) from apache_beam.ml.inference.base import KeyedModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that -from apache_beam.testing.util import equal_to + +# Module-level matcher for assert_that (must be picklable; no closure over self). +REQUIRED_OUTPUT_KEYS = ( + 'row_key', 'prediction', 'input_feature1', 'input_feature2', + 'input_feature3') + + +def _assert_table_inference_outputs(outputs): + """Asserts pipeline output has expected structure. 
Used in assert_that.""" + if len(outputs) != 2: + raise AssertionError(f'Expected 2 outputs, got {len(outputs)}') + for output in outputs: + for key in REQUIRED_OUTPUT_KEYS: + if key not in output: + raise AssertionError(f'Missing key {key!r} in output {output}') + try: from sklearn.linear_model import LinearRegression @@ -54,26 +70,23 @@ class TableRowInferenceTest(unittest.TestCase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() self.model_path = os.path.join(self.tmp_dir, 'test_model.pkl') - + model = LinearRegression() X_train = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) y_train = np.array([6, 15, 24]) model.fit(X_train, y_train) - + with open(self.model_path, 'wb') as f: pickle.dump(model, f) def test_parse_json_to_table_row(self): json_data = json.dumps({ - 'id': 'test_1', - 'feature1': 1.0, - 'feature2': 2.0, - 'feature3': 3.0 + 'id': 'test_1', 'feature1': 1.0, 'feature2': 2.0, 'feature3': 3.0 }).encode('utf-8') - + key, row = parse_json_to_table_row( json_data, schema_fields=['feature1', 'feature2', 'feature3']) - + self.assertEqual(key, 'test_1') self.assertEqual(row.feature1, 1.0) self.assertEqual(row.feature2, 2.0) @@ -82,28 +95,32 @@ def test_parse_json_to_table_row(self): def test_build_output_schema(self): feature_cols = ['feature1', 'feature2', 'feature3'] schema = build_output_schema(feature_cols) - + expected_fields = [ - 'row_key', 'prediction', 'model_id', 'input_feature1', - 'input_feature2', 'input_feature3' + 'row_key', + 'prediction', + 'model_id', + 'input_feature1', + 'input_feature2', + 'input_feature3' ] - + for field in expected_fields: self.assertIn(field, schema) def test_table_row_model_handler(self): model_handler = TableRowModelHandler( model_uri=self.model_path, feature_columns=['f1', 'f2', 'f3']) - + model = model_handler.load_model() - + test_rows = [ beam.Row(f1=1.0, f2=2.0, f3=3.0), beam.Row(f1=4.0, f2=5.0, f3=6.0), ] - + results = list(model_handler.run_inference(test_rows, model)) - + self.assertEqual(len(results), 2) self.assertIsInstance(results[0], PredictionResult) self.assertEqual(results[0].example, test_rows[0]) @@ -113,16 +130,16 @@ def test_format_table_output(self): row = beam.Row(feature1=1.0, feature2=2.0, feature3=3.0) prediction_result = PredictionResult( example=row, inference=6.0, model_id='test_model') - + keyed_result = ('test_key', prediction_result) - + feature_columns = ['feature1', 'feature2', 'feature3'] formatter = FormatTableOutput(feature_columns=feature_columns) outputs = list(formatter.process(keyed_result)) - + self.assertEqual(len(outputs), 1) output = outputs[0] - + self.assertEqual(output['row_key'], 'test_key') self.assertEqual(output['prediction'], 6.0) self.assertEqual(output['model_id'], 'test_model') @@ -139,11 +156,11 @@ def test_pipeline_integration(self): 'id': 'row_2', 'feature1': 4.0, 'feature2': 5.0, 'feature3': 6.0 }), ] - + feature_columns = ['feature1', 'feature2', 'feature3'] model_handler = TableRowModelHandler( model_uri=self.model_path, feature_columns=feature_columns) - + with TestPipeline() as p: input_data = ( p @@ -151,40 +168,25 @@ def test_pipeline_integration(self): | beam.Map( lambda line: parse_json_to_table_row( line.encode('utf-8'), feature_columns))) - + predictions = ( input_data | RunInference(KeyedModelHandler(model_handler)) | beam.ParDo(FormatTableOutput(feature_columns=feature_columns))) - - def check_outputs(outputs): - self.assertEqual(len(outputs), 2) - - for output in outputs: - self.assertIn('row_key', output) - self.assertIn('prediction', output) - 
self.assertIn('input_feature1', output) - self.assertIn('input_feature2', output) - self.assertIn('input_feature3', output) - - assert_that(predictions, check_outputs) + + assert_that(predictions, _assert_table_inference_outputs) class TableRowInferenceNoSklearnTest(unittest.TestCase): """Tests that don't require sklearn.""" def test_parse_json_without_schema(self): json_data = json.dumps({'id': 'test', 'value': 123}).encode('utf-8') - + key, row = parse_json_to_table_row(json_data) - + self.assertEqual(key, 'test') self.assertTrue(hasattr(row, 'value')) if __name__ == '__main__': unittest.main() - - - - - diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py b/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py index 0c5f9814ca2f..449163d74a14 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_utils.py @@ -64,11 +64,11 @@ def create_sample_model(output_path: str, num_features: int = 3): with open(output_path, 'wb') as f: pickle.dump(model, f) - logging.info(f'Sample model saved to {output_path}') + logging.info('Sample model saved to %s', output_path) -def generate_sample_data( - num_rows: int = 100, num_features: int = 3) -> list[dict]: +def generate_sample_data(num_rows: int = 100, + num_features: int = 3) -> list[dict]: """Generate sample table row data for testing. Args: @@ -104,7 +104,7 @@ def write_data_to_file(data: list[dict], output_path: str): for row in data: f.write(json.dumps(row) + '\n') - logging.info(f'Wrote {len(data)} rows to {output_path}') + logging.info('Wrote %d rows to %s', len(data), output_path) def publish_to_pubsub( @@ -130,15 +130,15 @@ def publish_to_pubsub( for i, row in enumerate(data): message = json.dumps(row).encode('utf-8') - future = publisher.publish(topic_path, message) + publisher.publish(topic_path, message) if (i + 1) % 100 == 0: - logging.info(f'Published {i + 1} messages') + logging.info('Published %d messages', i + 1) if delay > 0: time.sleep(delay) - logging.info(f'Published {len(data)} messages to {topic_path}') + logging.info('Published %d messages to %s', len(data), topic_path) def ensure_pubsub_topic(project: str, topic_name: str) -> str: @@ -159,10 +159,10 @@ def ensure_pubsub_topic(project: str, topic_name: str) -> str: try: publisher.get_topic(request={'topic': topic_path}) - logging.info(f'Topic {topic_name} already exists') + logging.info('Topic %s already exists', topic_name) except Exception: publisher.create_topic(name=topic_path) - logging.info(f'Created topic {topic_name}') + logging.info('Created topic %s', topic_name) return topic_path @@ -190,11 +190,10 @@ def ensure_pubsub_subscription( try: subscriber.get_subscription(request={'subscription': subscription_path}) - logging.info(f'Subscription {subscription_name} already exists') + logging.info('Subscription %s already exists', subscription_name) except Exception: - subscriber.create_subscription( - name=subscription_path, topic=topic_path) - logging.info(f'Created subscription {subscription_name}') + subscriber.create_subscription(name=subscription_path, topic=topic_path) + logging.info('Created subscription %s', subscription_name) return subscription_path @@ -217,17 +216,18 @@ def cleanup_pubsub_resources( if subscription_name: subscription_path = subscriber.subscription_path(project, subscription_name) try: - subscriber.delete_subscription(request={'subscription': subscription_path}) - logging.info(f'Deleted subscription 
{subscription_name}') + subscriber.delete_subscription( + request={'subscription': subscription_path}) + logging.info('Deleted subscription %s', subscription_name) except Exception as e: - logging.warning(f'Failed to delete subscription: {e}') + logging.warning('Failed to delete subscription: %s', e) topic_path = publisher.topic_path(project, topic_name) try: publisher.delete_topic(request={'topic': topic_path}) - logging.info(f'Deleted topic {topic_name}') + logging.info('Deleted topic %s', topic_name) except Exception as e: - logging.warning(f'Failed to delete topic: {e}') + logging.warning('Failed to delete topic: %s', e) if __name__ == '__main__': @@ -240,8 +240,12 @@ def cleanup_pubsub_resources( '--action', required=True, choices=[ - 'create_model', 'generate_data', 'publish_data', 'create_topic', - 'create_subscription', 'cleanup' + 'create_model', + 'generate_data', + 'publish_data', + 'create_topic', + 'create_subscription', + 'cleanup' ], help='Action to perform') parser.add_argument('--project', help='GCP project ID') @@ -251,14 +255,9 @@ def cleanup_pubsub_resources( parser.add_argument( '--num_rows', type=int, default=100, help='Number of rows to generate') parser.add_argument( - '--num_features', - type=int, - default=3, - help='Number of features per row') + '--num_features', type=int, default=3, help='Number of features per row') parser.add_argument( - '--rate_limit', - type=float, - help='Rate limit for publishing (rows/sec)') + '--rate_limit', type=float, help='Rate limit for publishing (rows/sec)') args = parser.parse_args() @@ -287,8 +286,8 @@ def cleanup_pubsub_resources( elif args.action == 'create_subscription': if not args.project or not args.topic or not args.subscription: raise ValueError( - '--project, --topic, and --subscription required for create_subscription' - ) + '--project, --topic, and --subscription required for ' + 'create_subscription') ensure_pubsub_subscription(args.project, args.topic, args.subscription) elif args.action == 'cleanup': diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py index 5ddf3cf40306..bca5263b9f9d 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py +++ b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py @@ -41,7 +41,6 @@ class TableRowInferenceOptions( DebugOptions, SetupOptions, ): - @classmethod def _add_argparse_args(cls, parser): parser.add_argument('--mode', default='batch') @@ -73,8 +72,8 @@ def __init__(self): metrics_namespace=self.metrics_namespace, is_streaming=False, pcollection='RunInference/BeamML_RunInference_Postprocess-0.out0') - self.is_streaming = ( - (self.pipeline.get_option('mode') or 'batch') == 'streaming') + self.is_streaming = ((self.pipeline.get_option('mode') or + 'batch') == 'streaming') if self.is_streaming: self.subscription = self.pipeline.get_option('input_subscription') diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py index ba621fb6fdf6..19ffc7674526 100644 --- a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -177,7 +177,9 @@ def _get_worker_time_interval( logging.warning( 'Could not determine worker time interval. 
' 'start_time=%s, end_time=%s, total messages=%d', - start_time, end_time, len(all_messages)) + start_time, + end_time, + len(all_messages)) return start_time, end_time def _get_throughput_metrics( @@ -226,10 +228,7 @@ def _get_throughput_metrics( return metrics def _get_streaming_throughput_metrics( - self, - project: str, - start_time: str, - end_time: str) -> dict[str, float]: + self, project: str, start_time: str, end_time: str) -> dict[str, float]: if not self.subscription: return {'AvgThroughputBytes': 0.0, 'AvgThroughputElements': 0.0} @@ -261,8 +260,7 @@ def _get_streaming_throughput_metrics( metrics = {} for key, req in requests.items(): - time_series = list( - self.monitoring_client.list_time_series(request=req)) + time_series = list(self.monitoring_client.list_time_series(request=req)) values = [ point.value.double_value for series in time_series for point in series.points @@ -299,7 +297,9 @@ def _get_additional_metrics(self, logging.warning( 'No throughput data for PCollection "%s". Check Dataflow job %s ' 'graph for actual PCollection names (Runner v2 may use different ' - 'naming).', self.pcollection, job_id) + 'naming).', + self.pcollection, + job_id) return { **throughput_metrics, "JobRuntimeSeconds": runtime_seconds, From 8be7201829d9c4a477edb79f15127b9e6e415d23 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Fri, 20 Feb 2026 14:04:13 +0200 Subject: [PATCH 3/5] Fix formatting --- .../examples/inference/table_row_inference_batch.py | 3 +-- .../examples/inference/table_row_inference_test.py | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py index 68d42f4a9594..f3a60bada51f 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_batch.py @@ -234,8 +234,7 @@ def run_batch_inference( logging.info(' Input: %s', input_file) logging.info(' Model: %s', model_path) logging.info(' Features: %s', feature_columns) - logging.info( - ' Output: %s', output_table if output_table else output_file) + logging.info(' Output: %s', output_table if output_table else output_file) with beam.Pipeline(options=pipeline_options) as pipeline: diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py index a164a0951d9c..2672af531c44 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py @@ -38,7 +38,10 @@ # Module-level matcher for assert_that (must be picklable; no closure over self). 
REQUIRED_OUTPUT_KEYS = ( - 'row_key', 'prediction', 'input_feature1', 'input_feature2', + 'row_key', + 'prediction', + 'input_feature1', + 'input_feature2', 'input_feature3') From d261aa16598ed43178a9d886578f0b5b017bef63 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Fri, 20 Feb 2026 17:02:15 +0200 Subject: [PATCH 4/5] Fix pylint and RunInference singleton test --- .../examples/inference/table_row_inference_test.py | 3 ++- sdks/python/apache_beam/ml/inference/base_test.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py index 2672af531c44..e0a0c09a7c6a 100644 --- a/sdks/python/apache_beam/examples/inference/table_row_inference_test.py +++ b/sdks/python/apache_beam/examples/inference/table_row_inference_test.py @@ -36,7 +36,8 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that -# Module-level matcher for assert_that (must be picklable; no closure over self). +# Module-level matcher for assert_that (must be picklable; no closure over +# self). REQUIRED_OUTPUT_KEYS = ( 'row_key', 'prediction', diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 381bf5456604..d3dd830fd88c 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1155,10 +1155,9 @@ def test_run_inference_with_iterable_side_input(self): FakeModelHandler(), model_metadata_pcoll=side_input)) test_pipeline.run() - self.assertTrue( - 'PCollection of size 2 with more than one element accessed as a ' - 'singleton view. First two elements encountered are' in str( - e.exception)) + msg = str(e.exception) + self.assertIn('singleton', msg, msg='Expected singleton view error') + self.assertIn('more than one', msg, msg='Expected multiple-elements error') def test_run_inference_with_iterable_side_input_multi_process_shared(self): test_pipeline = TestPipeline() @@ -1180,10 +1179,9 @@ def test_run_inference_with_iterable_side_input_multi_process_shared(self): model_metadata_pcoll=side_input)) test_pipeline.run() - self.assertTrue( - 'PCollection of size 2 with more than one element accessed as a ' - 'singleton view. 
First two elements encountered are' in str( - e.exception)) + msg = str(e.exception) + self.assertIn('singleton', msg, msg='Expected singleton view error') + self.assertIn('more than one', msg, msg='Expected multiple-elements error') def test_run_inference_empty_side_input(self): model_handler = FakeModelHandlerReturnsPredictionResult() From 331aa6470e01a80167427e2580cdca9b857aa017 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Fri, 20 Feb 2026 20:03:39 +0200 Subject: [PATCH 5/5] Fix Pylint in dataflow_cost_benchmark and singleton assertion in base_test --- .../testing/load_tests/dataflow_cost_benchmark.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py index 19ffc7674526..3750dcb5ba38 100644 --- a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -91,7 +91,7 @@ def run(self) -> None: logging.info(self.extra_metrics) self._metrics_monitor.publish_metrics(self.result, self.extra_metrics) - except Exception as e: + except Exception: raise finally: self.cleanup() @@ -188,7 +188,8 @@ def _get_throughput_metrics( job_id: str, start_time: str, end_time: str, - pcollection_name: Optional[str] = None) -> dict[str, float]: + pcollection_name: Optional[str] = None, + ) -> dict[str, float]: """Query Cloud Monitoring for per-PCollection throughput.""" name = pcollection_name if pcollection_name is not None else self.pcollection interval = monitoring_v3.TimeInterval(