From 247c91313a6d40630578c94c4d30d4e158fc951f Mon Sep 17 00:00:00 2001 From: Ashish Shubham Date: Sun, 15 Feb 2026 15:12:03 -0800 Subject: [PATCH 1/3] Keep the temp: state keys in the in-memory context session for downstream consumers --- src/google/adk/sessions/base_session_service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83e0..11c3f44590 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -128,6 +128,4 @@ def _update_session_state(self, session: Session, event: Event) -> None: if not event.actions or not event.actions.state_delta: return for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue session.state.update({key: value}) From a0f8f38481ded64cfdf5593f49aa67a9671e9dbb Mon Sep 17 00:00:00 2001 From: Ashish Shubham Date: Sun, 15 Feb 2026 15:25:43 -0800 Subject: [PATCH 2/3] Add UT --- tests/unittests/test_runners.py | 85 +++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb37533..ff4bd673e2 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1237,5 +1237,90 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent( assert "actual_name" in runner._app_name_alignment_hint +@pytest.mark.asyncio +async def test_temp_state_accessible_in_callbacks_but_not_persisted(): + """Tests that temp: state variables are accessible during lifecycle callbacks + but not persisted in the session.""" + + # Track what state was seen during callbacks + state_seen_in_before_agent = {} + state_seen_in_after_agent = {} + + class StateAccessPlugin(BasePlugin): + """Plugin that accesses state during callbacks.""" + + async def before_agent_callback(self, *, agent, callback_context): + # Set a temp state variable + callback_context.state["temp:test_key"] = "test_value" + callback_context.state["normal_key"] = "normal_value" + + # Verify we can read it back immediately + state_seen_in_before_agent["temp:test_key"] = callback_context.state.get( + "temp:test_key" + ) + state_seen_in_before_agent["normal_key"] = callback_context.state.get( + "normal_key" + ) + return None + + async def after_agent_callback(self, *, agent, callback_context): + # Verify temp state is still accessible during the same invocation + state_seen_in_after_agent["temp:test_key"] = callback_context.state.get( + "temp:test_key" + ) + state_seen_in_after_agent["normal_key"] = callback_context.state.get( + "normal_key" + ) + return None + + # Setup + session_service = InMemorySessionService() + plugin = StateAccessPlugin(name="state_access") + + agent = MockAgent(name="test_agent") + runner = Runner( + app_name=TEST_APP_ID, + agent=agent, + session_service=session_service, + plugins=[plugin], + auto_create_session=True, + ) + + # Run the agent + events = [] + async for event in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="test message")] + ), + ): + events.append(event) + + # Verify temp state was accessible during callbacks + assert state_seen_in_before_agent["temp:test_key"] == "test_value" + assert state_seen_in_before_agent["normal_key"] == "normal_value" + assert state_seen_in_after_agent["temp:test_key"] == "test_value" + assert state_seen_in_after_agent["normal_key"] == "normal_value" + + # Verify temp state is NOT persisted in the session + session = await session_service.get_session( + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + ) + + # Normal state should be persisted + assert session.state.get("normal_key") == "normal_value" + + # Temp state should NOT be persisted + assert "temp:test_key" not in session.state + + # Verify temp state is also not in any event's state_delta + for event in session.events: + if event.actions and event.actions.state_delta: + assert "temp:test_key" not in event.actions.state_delta + + if __name__ == "__main__": pytest.main([__file__]) From f6a234797ff3123cac5157307da196cd16b40d15 Mon Sep 17 00:00:00 2001 From: Ashish Shubham Date: Sun, 15 Feb 2026 15:34:54 -0800 Subject: [PATCH 3/3] Address review comments --- .../adk/sessions/base_session_service.py | 8 +- tests/unittests/test_runners.py | 78 ++++++++++++++++++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 11c3f44590..6fb287b980 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,8 +106,11 @@ async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: return event - event = self._trim_temp_delta_state(event) + # Update session state with ALL keys (including temp:) so they're accessible + # during callbacks within the same invocation self._update_session_state(session, event) + # Trim temp: keys from the event before persisting to avoid storing them + event = self._trim_temp_delta_state(event) session.events.append(event) return event @@ -127,5 +130,4 @@ def _update_session_state(self, session: Session, event: Event) -> None: """Updates the session state based on the event.""" if not event.actions or not event.actions.state_delta: return - for key, value in event.actions.state_delta.items(): - session.state.update({key: value}) + session.state.update(event.actions.state_delta) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ff4bd673e2..f01c9c9802 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1287,15 +1287,14 @@ async def after_agent_callback(self, *, agent, callback_context): ) # Run the agent - events = [] - async for event in runner.run_async( + async for _ in runner.run_async( user_id=TEST_USER_ID, session_id=TEST_SESSION_ID, new_message=types.Content( role="user", parts=[types.Part(text="test message")] ), ): - events.append(event) + pass # Verify temp state was accessible during callbacks assert state_seen_in_before_agent["temp:test_key"] == "test_value" @@ -1322,5 +1321,78 @@ async def after_agent_callback(self, *, agent, callback_context): assert "temp:test_key" not in event.actions.state_delta +@pytest.mark.asyncio +async def test_temp_state_from_state_delta_accessible_in_callbacks(): + """Tests that temp: state set via run_async state_delta parameter is + accessible during lifecycle callbacks but not persisted.""" + + # Track what state was seen during callbacks + state_seen_in_before_agent = {} + + class StateAccessPlugin(BasePlugin): + """Plugin that accesses state during callbacks.""" + + async def before_agent_callback(self, *, agent, callback_context): + # Check if temp state from state_delta is accessible + state_seen_in_before_agent["temp:from_run_async"] = ( + callback_context.state.get("temp:from_run_async") + ) + state_seen_in_before_agent["normal:from_run_async"] = ( + callback_context.state.get("normal:from_run_async") + ) + return None + + # Setup + session_service = InMemorySessionService() + plugin = StateAccessPlugin(name="state_access") + + agent = MockAgent(name="test_agent") + runner = Runner( + app_name=TEST_APP_ID, + agent=agent, + session_service=session_service, + plugins=[plugin], + auto_create_session=True, + ) + + # Run the agent with state_delta containing both temp and normal keys + async for _ in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="test message")] + ), + state_delta={ + "temp:from_run_async": "temp_value", + "normal:from_run_async": "normal_value", + }, + ): + pass + + # Verify temp state from state_delta WAS accessible during callbacks + assert ( + state_seen_in_before_agent["temp:from_run_async"] == "temp_value" + ), "temp: state from state_delta should be accessible in callbacks" + assert state_seen_in_before_agent["normal:from_run_async"] == "normal_value" + + # Verify temp state is NOT persisted in the session + session = await session_service.get_session( + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + ) + + # Normal state should be persisted + assert session.state.get("normal:from_run_async") == "normal_value" + + # Temp state should NOT be persisted + assert "temp:from_run_async" not in session.state + + # Verify temp state is also not in any event's state_delta + for event in session.events: + if event.actions and event.actions.state_delta: + assert "temp:from_run_async" not in event.actions.state_delta + + if __name__ == "__main__": pytest.main([__file__])