Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class BedrockConfig(TypedDict, total=False):
cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
cache_tools: Cache point type for tools
cache_tools_ttl: Time-to-live for the tools cache point. Supported values are "5m" and "1h".
If None, the provider default (5 minutes) is used.
guardrail_id: ID of the guardrail to apply
guardrail_trace: Guardrail trace mode. Defaults to enabled.
guardrail_version: Version of the guardrail to apply
Expand All @@ -105,6 +107,7 @@ class BedrockConfig(TypedDict, total=False):
cache_prompt: str | None
cache_config: CacheConfig | None
cache_tools: str | None
cache_tools_ttl: Literal["5m", "1h"] | None
guardrail_id: str | None
guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None
guardrail_stream_processing_mode: Literal["sync", "async"] | None
Expand Down Expand Up @@ -260,7 +263,18 @@ def _format_request(
for tool_spec in tool_specs
],
*(
[{"cachePoint": {"type": self.config["cache_tools"]}}]
[
{
"cachePoint": {
"type": self.config["cache_tools"],
**(
{"ttl": self.config["cache_tools_ttl"]}
if self.config.get("cache_tools_ttl")
else {}
),
}
}
]
if self.config.get("cache_tools")
else []
),
Expand Down Expand Up @@ -338,11 +352,12 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict

return {"additionalModelRequestFields": additional_fields}

def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
def _inject_cache_point(self, messages: list[dict[str, Any]], ttl: str | None = None) -> None:
"""Inject a cache point at the end of the last user message.

Args:
messages: List of messages to inject cache point into (modified in place).
ttl: Optional TTL for the cache point (e.g., "5m", "1h").
"""
if not messages:
return
Expand All @@ -362,7 +377,10 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
last_user_idx = msg_idx

if last_user_idx is not None and messages[last_user_idx].get("content"):
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
cache_point: dict[str, Any] = {"type": "default"}
if ttl:
cache_point["ttl"] = ttl
messages[last_user_idx]["content"].append({"cachePoint": cache_point})
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)

def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
Expand Down Expand Up @@ -471,7 +489,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
self.config.get("model_id"),
)
if strategy == "anthropic":
self._inject_cache_point(cleaned_messages)
self._inject_cache_point(cleaned_messages, ttl=cache_config.ttl)

return cleaned_messages

Expand Down Expand Up @@ -515,7 +533,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
"""
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
if "cachePoint" in content:
return {"cachePoint": {"type": content["cachePoint"]["type"]}}
cache_point: dict[str, Any] = {"type": content["cachePoint"]["type"]}
ttl = content["cachePoint"].get("ttl")
if ttl:
cache_point["ttl"] = ttl
return {"cachePoint": cache_point}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
if "document" in content:
Expand Down
4 changes: 4 additions & 0 deletions src/strands/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ class CacheConfig:
strategy: Caching strategy to use.
- "auto": Automatically detect model support and inject cachePoint to maximize cache coverage
- "anthropic": Inject cachePoint in Anthropic-compatible format without model support check
ttl: Time-to-live for cache points. Supported values depend on the model provider.
For Bedrock, supported values are "5m" (5 minutes) and "1h" (1 hour).
If None, the provider default (typically 5 minutes) is used.
"""

strategy: Literal["auto", "anthropic"] = "auto"
ttl: Literal["5m", "1h"] | None = None


class Model(abc.ABC):
Expand Down
120 changes: 120 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2809,3 +2809,123 @@ def test_guardrail_latest_message_disabled_does_not_wrap(model):

assert "text" in formatted
assert "guardContent" not in formatted


def test_inject_cache_point_with_ttl(bedrock_client):
"""Test that _inject_cache_point includes TTL when configured."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
cache_config=CacheConfig(strategy="auto", ttl="1h"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Hi there!"}]},
{"role": "user", "content": [{"text": "How are you?"}]},
]

model._inject_cache_point(cleaned_messages, ttl="1h")

cache_point = cleaned_messages[2]["content"][-1]
assert "cachePoint" in cache_point
assert cache_point["cachePoint"]["type"] == "default"
assert cache_point["cachePoint"]["ttl"] == "1h"


def test_inject_cache_point_without_ttl_has_no_ttl_field(bedrock_client):
"""Test that _inject_cache_point does not include TTL when not passed."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages)

cache_point = cleaned_messages[0]["content"][-1]
assert "cachePoint" in cache_point
assert cache_point["cachePoint"]["type"] == "default"
assert "ttl" not in cache_point["cachePoint"]


def test_inject_cache_point_with_5m_ttl(bedrock_client):
"""Test that _inject_cache_point includes 5m TTL when configured."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto", ttl="5m"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages, ttl="5m")

cache_point = cleaned_messages[0]["content"][-1]
assert cache_point["cachePoint"]["ttl"] == "5m"


def test_format_bedrock_messages_passes_ttl_from_cache_config(bedrock_client):
"""Test that _format_bedrock_messages passes cache_config.ttl to _inject_cache_point."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
cache_config=CacheConfig(strategy="auto", ttl="1h"),
)

messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

formatted = model._format_bedrock_messages(messages)

cache_point = formatted[0]["content"][-1]
assert "cachePoint" in cache_point
assert cache_point["cachePoint"]["ttl"] == "1h"


def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec):
"""Test that cache_tools_ttl is included in the tools cache point."""
model.update_config(cache_tools="default", cache_tools_ttl="1h")

tru_request = model._format_request(messages, tool_specs=[tool_spec])

tools = tru_request["toolConfig"]["tools"]
cache_point_block = tools[-1]
assert "cachePoint" in cache_point_block
assert cache_point_block["cachePoint"]["type"] == "default"
assert cache_point_block["cachePoint"]["ttl"] == "1h"


def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec):
"""Test that no TTL is included in tools cache point when cache_tools_ttl is not set."""
model.update_config(cache_tools="default")

tru_request = model._format_request(messages, tool_specs=[tool_spec])

tools = tru_request["toolConfig"]["tools"]
cache_point_block = tools[-1]
assert "cachePoint" in cache_point_block
assert cache_point_block["cachePoint"]["type"] == "default"
assert "ttl" not in cache_point_block["cachePoint"]


def test_format_bedrock_content_block_cache_point_with_ttl(model):
"""Test that _format_request_message_content preserves TTL on cachePoint blocks."""
content = {"cachePoint": {"type": "default", "ttl": "1h"}}

result = model._format_request_message_content(content)

assert result == {"cachePoint": {"type": "default", "ttl": "1h"}}


def test_format_bedrock_content_block_cache_point_without_ttl(model):
"""Test that _format_request_message_content does not add TTL when not present."""
content = {"cachePoint": {"type": "default"}}

result = model._format_request_message_content(content)

assert result == {"cachePoint": {"type": "default"}}
assert "ttl" not in result["cachePoint"]
Loading