diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83e0..6fb287b980 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,8 +106,11 @@ async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: return event - event = self._trim_temp_delta_state(event) + # Update session state with ALL keys (including temp:) so they're accessible + # during callbacks within the same invocation self._update_session_state(session, event) + # Trim temp: keys from the event before persisting to avoid storing them + event = self._trim_temp_delta_state(event) session.events.append(event) return event @@ -127,7 +130,4 @@ def _update_session_state(self, session: Session, event: Event) -> None: """Updates the session state based on the event.""" if not event.actions or not event.actions.state_delta: return - for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue - session.state.update({key: value}) + session.state.update(event.actions.state_delta) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb37533..f01c9c9802 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -1237,5 +1237,162 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent( assert "actual_name" in runner._app_name_alignment_hint +@pytest.mark.asyncio +async def test_temp_state_accessible_in_callbacks_but_not_persisted(): + """Tests that temp: state variables are accessible during lifecycle callbacks + but not persisted in the session.""" + + # Track what state was seen during callbacks + state_seen_in_before_agent = {} + state_seen_in_after_agent = {} + + class StateAccessPlugin(BasePlugin): + """Plugin that accesses state during callbacks.""" + + async def before_agent_callback(self, *, agent, callback_context): + # Set a temp state variable + callback_context.state["temp:test_key"] = "test_value" + callback_context.state["normal_key"] = "normal_value" + + # Verify we can read it back immediately + state_seen_in_before_agent["temp:test_key"] = callback_context.state.get( + "temp:test_key" + ) + state_seen_in_before_agent["normal_key"] = callback_context.state.get( + "normal_key" + ) + return None + + async def after_agent_callback(self, *, agent, callback_context): + # Verify temp state is still accessible during the same invocation + state_seen_in_after_agent["temp:test_key"] = callback_context.state.get( + "temp:test_key" + ) + state_seen_in_after_agent["normal_key"] = callback_context.state.get( + "normal_key" + ) + return None + + # Setup + session_service = InMemorySessionService() + plugin = StateAccessPlugin(name="state_access") + + agent = MockAgent(name="test_agent") + runner = Runner( + app_name=TEST_APP_ID, + agent=agent, + session_service=session_service, + plugins=[plugin], + auto_create_session=True, + ) + + # Run the agent + async for _ in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="test message")] + ), + ): + pass + + # Verify temp state was accessible during callbacks + assert state_seen_in_before_agent["temp:test_key"] == "test_value" + assert state_seen_in_before_agent["normal_key"] == "normal_value" + assert state_seen_in_after_agent["temp:test_key"] == "test_value" + assert state_seen_in_after_agent["normal_key"] == "normal_value" + + # Verify temp state is NOT persisted in the session + session = await session_service.get_session( + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + ) + + # Normal state should be persisted + assert session.state.get("normal_key") == "normal_value" + + # Temp state should NOT be persisted + assert "temp:test_key" not in session.state + + # Verify temp state is also not in any event's state_delta + for event in session.events: + if event.actions and event.actions.state_delta: + assert "temp:test_key" not in event.actions.state_delta + + +@pytest.mark.asyncio +async def test_temp_state_from_state_delta_accessible_in_callbacks(): + """Tests that temp: state set via run_async state_delta parameter is + accessible during lifecycle callbacks but not persisted.""" + + # Track what state was seen during callbacks + state_seen_in_before_agent = {} + + class StateAccessPlugin(BasePlugin): + """Plugin that accesses state during callbacks.""" + + async def before_agent_callback(self, *, agent, callback_context): + # Check if temp state from state_delta is accessible + state_seen_in_before_agent["temp:from_run_async"] = ( + callback_context.state.get("temp:from_run_async") + ) + state_seen_in_before_agent["normal:from_run_async"] = ( + callback_context.state.get("normal:from_run_async") + ) + return None + + # Setup + session_service = InMemorySessionService() + plugin = StateAccessPlugin(name="state_access") + + agent = MockAgent(name="test_agent") + runner = Runner( + app_name=TEST_APP_ID, + agent=agent, + session_service=session_service, + plugins=[plugin], + auto_create_session=True, + ) + + # Run the agent with state_delta containing both temp and normal keys + async for _ in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="test message")] + ), + state_delta={ + "temp:from_run_async": "temp_value", + "normal:from_run_async": "normal_value", + }, + ): + pass + + # Verify temp state from state_delta WAS accessible during callbacks + assert ( + state_seen_in_before_agent["temp:from_run_async"] == "temp_value" + ), "temp: state from state_delta should be accessible in callbacks" + assert state_seen_in_before_agent["normal:from_run_async"] == "normal_value" + + # Verify temp state is NOT persisted in the session + session = await session_service.get_session( + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + ) + + # Normal state should be persisted + assert session.state.get("normal:from_run_async") == "normal_value" + + # Temp state should NOT be persisted + assert "temp:from_run_async" not in session.state + + # Verify temp state is also not in any event's state_delta + for event in session.events: + if event.actions and event.actions.state_delta: + assert "temp:from_run_async" not in event.actions.state_delta + + if __name__ == "__main__": pytest.main([__file__])