diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 5de34a6c2..095ad6aea 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -79,6 +79,8 @@ class BedrockConfig(TypedDict, total=False): cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. cache_tools: Cache point type for tools + cache_tools_ttl: Time-to-live for the tools cache point. Supported values are "5m" and "1h". + If None, the provider default (5 minutes) is used. guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -105,6 +107,7 @@ class BedrockConfig(TypedDict, total=False): cache_prompt: str | None cache_config: CacheConfig | None cache_tools: str | None + cache_tools_ttl: Literal["5m", "1h"] | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -260,7 +263,18 @@ def _format_request( for tool_spec in tool_specs ], *( - [{"cachePoint": {"type": self.config["cache_tools"]}}] + [ + { + "cachePoint": { + "type": self.config["cache_tools"], + **( + {"ttl": self.config["cache_tools_ttl"]} + if self.config.get("cache_tools_ttl") + else {} + ), + } + } + ] if self.config.get("cache_tools") else [] ), @@ -338,11 +352,12 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} - def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: + def _inject_cache_point(self, messages: list[dict[str, Any]], ttl: str | None = None) -> None: """Inject a cache point at the end of the last user message. Args: messages: List of messages to inject cache point into (modified in place). + ttl: Optional TTL for the cache point (e.g., "5m", "1h"). """ if not messages: return @@ -362,7 +377,10 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: last_user_idx = msg_idx if last_user_idx is not None and messages[last_user_idx].get("content"): - messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + cache_point: dict[str, Any] = {"type": "default"} + if ttl: + cache_point["ttl"] = ttl + messages[last_user_idx]["content"].append({"cachePoint": cache_point}) logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: @@ -471,7 +489,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: self.config.get("model_id"), ) if strategy == "anthropic": - self._inject_cache_point(cleaned_messages) + self._inject_cache_point(cleaned_messages, ttl=cache_config.ttl) return cleaned_messages @@ -515,7 +533,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An """ # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html if "cachePoint" in content: - return {"cachePoint": {"type": content["cachePoint"]["type"]}} + cache_point: dict[str, Any] = {"type": content["cachePoint"]["type"]} + ttl = content["cachePoint"].get("ttl") + if ttl: + cache_point["ttl"] = ttl + return {"cachePoint": cache_point} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html if "document" in content: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..79365b357 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -30,9 +30,13 @@ class CacheConfig: strategy: Caching strategy to use. - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check + ttl: Time-to-live for cache points. Supported values depend on the model provider. + For Bedrock, supported values are "5m" (5 minutes) and "1h" (1 hour). + If None, the provider default (typically 5 minutes) is used. """ strategy: Literal["auto", "anthropic"] = "auto" + ttl: Literal["5m", "1h"] | None = None class Model(abc.ABC): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5f81efd24..4856d39b6 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2809,3 +2809,123 @@ def test_guardrail_latest_message_disabled_does_not_wrap(model): assert "text" in formatted assert "guardContent" not in formatted + + +def test_inject_cache_point_with_ttl(bedrock_client): + """Test that _inject_cache_point includes TTL when configured.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + model._inject_cache_point(cleaned_messages, ttl="1h") + + cache_point = cleaned_messages[2]["content"][-1] + assert "cachePoint" in cache_point + assert cache_point["cachePoint"]["type"] == "default" + assert cache_point["cachePoint"]["ttl"] == "1h" + + +def test_inject_cache_point_without_ttl_has_no_ttl_field(bedrock_client): + """Test that _inject_cache_point does not include TTL when not passed.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1] + assert "cachePoint" in cache_point + assert cache_point["cachePoint"]["type"] == "default" + assert "ttl" not in cache_point["cachePoint"] + + +def test_inject_cache_point_with_5m_ttl(bedrock_client): + """Test that _inject_cache_point includes 5m TTL when configured.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages, ttl="5m") + + cache_point = cleaned_messages[0]["content"][-1] + assert cache_point["cachePoint"]["ttl"] == "5m" + + +def test_format_bedrock_messages_passes_ttl_from_cache_config(bedrock_client): + """Test that _format_bedrock_messages passes cache_config.ttl to _inject_cache_point.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + cache_point = formatted[0]["content"][-1] + assert "cachePoint" in cache_point + assert cache_point["cachePoint"]["ttl"] == "1h" + + +def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec): + """Test that cache_tools_ttl is included in the tools cache point.""" + model.update_config(cache_tools="default", cache_tools_ttl="1h") + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + tools = tru_request["toolConfig"]["tools"] + cache_point_block = tools[-1] + assert "cachePoint" in cache_point_block + assert cache_point_block["cachePoint"]["type"] == "default" + assert cache_point_block["cachePoint"]["ttl"] == "1h" + + +def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec): + """Test that no TTL is included in tools cache point when cache_tools_ttl is not set.""" + model.update_config(cache_tools="default") + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + tools = tru_request["toolConfig"]["tools"] + cache_point_block = tools[-1] + assert "cachePoint" in cache_point_block + assert cache_point_block["cachePoint"]["type"] == "default" + assert "ttl" not in cache_point_block["cachePoint"] + + +def test_format_bedrock_content_block_cache_point_with_ttl(model): + """Test that _format_request_message_content preserves TTL on cachePoint blocks.""" + content = {"cachePoint": {"type": "default", "ttl": "1h"}} + + result = model._format_request_message_content(content) + + assert result == {"cachePoint": {"type": "default", "ttl": "1h"}} + + +def test_format_bedrock_content_block_cache_point_without_ttl(model): + """Test that _format_request_message_content does not add TTL when not present.""" + content = {"cachePoint": {"type": "default"}} + + result = model._format_request_message_content(content) + + assert result == {"cachePoint": {"type": "default"}} + assert "ttl" not in result["cachePoint"]