diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py new file mode 100644 index 0000000000..3ac201f7fe --- /dev/null +++ b/cassandra/client_routes.py @@ -0,0 +1,400 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Client Routes support for Private Link and similar network configurations. + +This module implements support for dynamic address translation via the +system.client_routes table and CLIENT_ROUTES_CHANGE events. +""" + +from __future__ import absolute_import + +from dataclasses import dataclass, replace +import logging +import threading +import uuid +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set + +from cassandra.query import dict_factory + +if TYPE_CHECKING: + from cassandra.cluster import ControlConnection + +log = logging.getLogger(__name__) + + +class ClientRoutesEndpoint: + + def __init__(self, connection_id: str, connection_addr: Optional[str] = None): + """ + :param connection_id: UUID string identifying the connection + :param connection_addr: Optional string address for initial connection + """ + if connection_id is None: + raise ValueError("connection_id is required") + + self.connection_id = connection_id + self.connection_addr = connection_addr + + def __repr__(self) -> str: + return f"ClientRoutesEndpoint(connection_id={self.connection_id}, connection_addr={self.connection_addr})" + + +class ClientRoutesConfig: + """ + Configuration for client routes (Private Link support). + + :param endpoints: List of ClientRoutesEndpoint objects (REQUIRED, at least one) + :param table_name: Name of the system table to query (default: "system.client_routes") + :param cache_ttl_seconds: How long DNS resolution results are cached (default: 300 seconds = 5 minutes) + """ + + def __init__(self, endpoints: List[ClientRoutesEndpoint], table_name: str = "system.client_routes", + cache_ttl_seconds: int = 300): + """ + :param endpoints: List of ClientRoutesEndpoint objects + :param table_name: System table name for route discovery + :param cache_ttl_seconds: DNS cache TTL in seconds (must be >= 0, default: 300 = 5 minutes) + """ + if not endpoints: + raise ValueError("At least one endpoint must be specified") + + if not isinstance(endpoints, (list, tuple)): + raise TypeError("endpoints must be a list or tuple") + + for endpoint in endpoints: + if not isinstance(endpoint, ClientRoutesEndpoint): + raise TypeError("All endpoints must be ClientRoutesEndpoint instances") + + if cache_ttl_seconds < 0: + raise ValueError("cache_ttl_seconds must be >= 0") + + self.endpoints = list(endpoints) + self.table_name = table_name + self.cache_ttl_seconds = cache_ttl_seconds + + def __repr__(self) -> str: + return f"ClientRoutesConfig(endpoints={self.endpoints}, table_name={self.table_name})" + + +# Internal data structures + +@dataclass +class ResolvedRoute: + connection_id: str + host_id: uuid.UUID + address: str # DNS hostname from system.client_routes + port: int + tls_port: Optional[int] + datacenter: Optional[str] + rack: Optional[str] + all_known_ips: Optional[List[str]] # List of all resolved IP addresses + current_ip: Optional[str] # Currently selected IP address + update_time: Optional[float] # Timestamp of last resolution + forced_resolve: bool # Flag to force resolution on next cycle + + # Compatibility helper for existing namedtuple-style updates in the file + def _replace(self, **changes) -> 'ResolvedRoute': + return replace(self, **changes) + +class ResolvedRoutes: + """ + Thread-safe storage for resolved routes using lock-free reads. + + This uses atomic pointer swaps for updates, allowing lock-free reads + while serializing writes. + """ + + def __init__(self) -> None: + self._routes_by_host_id: Dict[uuid.UUID, ResolvedRoute] = {} # Dict[UUID, ResolvedRoute] + self._lock = threading.RLock() + + def get_by_host_id(self, host_id: uuid.UUID) -> Optional[ResolvedRoute]: + """ + Get route for a host ID (lock-free read). + + :param host_id: UUID of the host + :return: ResolvedRoute or None + """ + return self._routes_by_host_id.get(host_id) + + def get_all(self) -> List[ResolvedRoute]: + """ + Get all routes as a list (lock-free read). + + :return: List of ResolvedRoute + """ + return list(self._routes_by_host_id.values()) + + def update(self, routes: List[ResolvedRoute]) -> None: + """ + Replace all routes atomically. + + :param routes: List of ResolvedRoute objects + """ + with self._lock: + self._routes_by_host_id = {route.host_id: route for route in routes} + + def merge(self, new_routes: List[ResolvedRoute]) -> None: + """ + Merge new routes with existing ones atomically. + + :param new_routes: List of ResolvedRoute objects to merge + """ + with self._lock: + updated = dict(self._routes_by_host_id) + for route in new_routes: + updated[route.host_id] = route + self._routes_by_host_id = updated + + def merge_with_unresolved(self, new_routes: List[ResolvedRoute]) -> None: + """ + Merge unresolved routes, marking changed ones for forced resolution. + + :param new_routes: List of ResolvedRoute objects from system.client_routes + """ + with self._lock: + updated = dict(self._routes_by_host_id) + + for new_route in new_routes: + key = new_route.host_id + existing = updated.get(key) + + if existing is None: + # New route, add with forced_resolve=True + updated[key] = new_route._replace(forced_resolve=True) + else: + # Check if route details changed (address, port, tls_port) + if (existing.connection_id != new_route.connection_id or + existing.address != new_route.address or + existing.port != new_route.port or + existing.tls_port != new_route.tls_port): + # Route changed, mark for forced resolution + updated[key] = new_route._replace(forced_resolve=True) + # Otherwise keep existing route with its resolution state + + self._routes_by_host_id = updated + + def update_single(self, host_id: uuid.UUID, update_fn: Callable[[ResolvedRoute], ResolvedRoute]) -> Optional[ResolvedRoute]: + """ + Update a single route using CAS (compare-and-swap) pattern. + + :param host_id: UUID of the host to update + :param update_fn: Function that takes existing route and returns updated route + :return: Updated ResolvedRoute or None if not found + """ + with self._lock: + existing = self._routes_by_host_id.get(host_id) + if existing: + updated = update_fn(existing) + self._routes_by_host_id[host_id] = updated + return updated + return None + + + +class ClientRoutesHandler: + """ + Handles dynamic address translation for Private Link via system.client_routes. + + Lifecycle: + 1. Construction: Create with configuration + 2. Initialization: Read system.client_routes after control connection established + 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes + 4. Translation: Translate addresses using Host ID lookup + 5. Shutdown: Clean up resources + """ + + def __init__(self, config: 'ClientRoutesConfig', ssl_enabled: bool = False): + """ + :param config: ClientRoutesConfig instance + :param ssl_enabled: Whether TLS is enabled (determines port selection) + """ + if not isinstance(config, ClientRoutesConfig): + raise TypeError("config must be a ClientRoutesConfig instance") + + self.config: ClientRoutesConfig = config + self.ssl_enabled: bool = ssl_enabled + self._routes: ResolvedRoutes = ResolvedRoutes() + self._initial_endpoints: Set[str] = {ep.connection_id for ep in config.endpoints} + self._is_shutdown: bool = False + self._lock = threading.RLock() + + def initialize(self, control_connection: 'ControlConnection') -> None: + """ + Initialize handler after control connection is established. + + Reads system.client_routes for all configured connection IDs. + DNS resolution happens at the Endpoint level on each connection attempt. + + :param control_connection: The ControlConnection instance + """ + if self._is_shutdown: + return + + log.info("[client routes] Initializing with %d endpoints", len(self.config.endpoints)) + + try: + connection_ids = [ep.connection_id for ep in self.config.endpoints] + routes = self._query_routes(control_connection, connection_ids=connection_ids) + + self._routes.merge_with_unresolved(routes) + + log.info("[client routes] Initialized with %d routes", len(self._routes.get_all())) + except Exception as e: + log.error("[client routes] Initialization failed: %s", e, exc_info=True) + raise + + def handle_client_routes_change(self, control_connection: 'ControlConnection', change_type: str, + connection_ids: Optional[List[str]], host_ids: Optional[List[str]]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event. + + :param control_connection: The ControlConnection instance + :param change_type: Type of change (e.g., "UPDATED") + :param connection_ids: List of affected connection ID strings + :param host_ids: List of affected host ID strings + """ + if self._is_shutdown: + return + + log.debug("[client routes] Handling CLIENT_ROUTES_CHANGE: change_type=%s, " + "connection_ids=%s, host_ids=%s", + change_type, connection_ids, host_ids) + + try: + filtered_conn_ids = None + if connection_ids: + configured_ids = {str(ep.connection_id) for ep in self.config.endpoints} + filtered = [cid for cid in connection_ids if cid in configured_ids] + if not filtered: + log.debug("[client routes] All connection IDs filtered out, ignoring event") + return + filtered_conn_ids = [uuid.UUID(cid) for cid in filtered] + + host_uuids = [uuid.UUID(hid) for hid in host_ids] if host_ids else None + + routes = self._query_routes( + control_connection, + connection_ids=filtered_conn_ids, + host_ids=host_uuids + ) + + self._routes.merge_with_unresolved(routes) + + log.debug("[client routes] Updated routes after CLIENT_ROUTES_CHANGE") + except Exception as e: + log.warning("[client routes] Failed to handle CLIENT_ROUTES_CHANGE: %s", e, exc_info=True) + + def handle_control_connection_reconnect(self, control_connection: 'ControlConnection') -> None: + """ + Handle control connection recreation - full re-read of all connection IDs. + + :param control_connection: The new ControlConnection instance + """ + if self._is_shutdown: + return + + log.info("[client routes] Control connection reconnected, re-reading all routes") + + try: + self.initialize(control_connection) + except Exception as e: + log.error("[client routes] Failed to re-initialize after reconnect: %s", e, exc_info=True) + + def _query_routes(self, control_connection: 'ControlConnection', connection_ids: Optional[List[uuid.UUID]] = None, + host_ids: Optional[List[uuid.UUID]] = None) -> List[ResolvedRoute]: + """ + Query system.client_routes table. + + :param control_connection: ControlConnection to execute query + :param connection_ids: Optional list of connection UUIDs to filter by + :param host_ids: Optional list of host UUIDs to filter by + :return: List of ResolvedRoute (with resolved_ip/resolved_at as None) + """ + query_parts = [f"SELECT * FROM {self.config.table_name}"] + where_clauses = [] + + if connection_ids: + conn_id_list = ', '.join(str(cid) for cid in connection_ids) + where_clauses.append(f"connection_id IN ({conn_id_list})") + + if host_ids: + host_id_list = ', '.join(str(hid) for hid in host_ids) + where_clauses.append(f"host_id IN ({host_id_list})") + + if where_clauses: + query_parts.append("WHERE " + " AND ".join(where_clauses)) + + if (not connection_ids or len(connection_ids) == 0) and (not host_ids or len(host_ids) == 0): + query_parts.append("ALLOW FILTERING") + query = " ".join(query_parts) + + log.debug("[client routes] Querying: %s", query) + + from cassandra.protocol import QueryMessage + from cassandra import ConsistencyLevel + + query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + result = control_connection._connection.wait_for_response( + query_msg, timeout=control_connection._timeout + ) + + routes = [] + if hasattr(result, 'parsed_rows') and result.parsed_rows: + rows = dict_factory( + result.column_names, + result.parsed_rows) + for row in rows: + try: + routes.append(ResolvedRoute( + connection_id=row['connection_id'], + host_id=row['host_id'], + address=row['address'], + port=row['port'], + tls_port=row.get('tls_port'), + datacenter=row.get('datacenter'), + rack=row.get('rack'), + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True # Force initial resolution + )) + except Exception as e: + log.warning("[client routes] Failed to parse route row: %s", e) + + return routes + + def get_route_by_host_id(self, host_id: uuid.UUID) -> Optional[ResolvedRoute]: + """ + Get route information for a given host_id. + + This is used by ClientRoutesEndPointFactory to create endpoints. + + :param host_id: Host UUID + :return: ResolvedRoute or None + """ + return self._routes.get_by_host_id(host_id) + + def shutdown(self) -> None: + """ + Shutdown the handler and release resources. + """ + with self._lock: + if self._is_shutdown: + return + + self._is_shutdown = True + log.info("[client routes] Handler shutdown") diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 622b706330..7f1db553de 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -29,7 +29,7 @@ from itertools import groupby, count, chain import json import logging -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from warnings import warn from random import random import re @@ -48,7 +48,8 @@ SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints, DependencyException) from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider -from cassandra.connection import (ConnectionException, ConnectionShutdown, +from cassandra.client_routes import ClientRoutesConfig, ClientRoutesHandler +from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, SniEndPointFactory, ConnectionBusy, locally_supported_compressions) @@ -804,10 +805,10 @@ def default_retry_policy(self, policy): :class:`.policies.SimpleConvictionPolicy`. """ - address_translator = IdentityTranslator() + address_translator = None """ :class:`.policies.AddressTranslator` instance to be used in translating server node addresses - to driver connection addresses. + to driver connection addresses. If :const:`None`, addresses are not translated. """ connect_to_remote_hosts = True @@ -1215,7 +1216,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info:Optional[ApplicationInfoBase]=None, + client_routes_config:Optional[ClientRoutesConfig]=None ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1280,6 +1282,36 @@ def __init__(self, if column_encryption_policy is not None: self.column_encryption_policy = column_encryption_policy + # Validate mutual exclusivity of client_routes_config and address_translator + if client_routes_config is not None and address_translator is not None: + raise ValueError("client_routes_config and address_translator are mutually exclusive") + + # Handle client routes configuration + self._client_routes_handler = None + if client_routes_config is not None: + if not isinstance(client_routes_config, ClientRoutesConfig): + raise TypeError("client_routes_config must be a ClientRoutesConfig instance") + + ssl_enabled = ssl_context is not None or ssl_options is not None + self._client_routes_handler = ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled) + + if contact_points is _NOT_SET or not self._contact_points_explicit: + seed_addrs = [ep.connection_addr for ep in client_routes_config.endpoints + if ep.connection_addr] + if seed_addrs: + self.contact_points = seed_addrs + self._contact_points_explicit = True + log.info("[client routes] Using %d endpoint connection addresses as contact points", + len(seed_addrs)) + elif address_translator is not None: + if isinstance(address_translator, type): + raise TypeError("address_translator should not be a class, it should be an instance of that class") + self.address_translator = address_translator + else: + self.address_translator = IdentityTranslator() + + if self._client_routes_handler is not None: + endpoint_factory = ClientRoutesEndPointFactory(self._client_routes_handler, self.port) self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) self.endpoint_factory.configure(self) @@ -1337,11 +1369,6 @@ def __init__(self, raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory - if address_translator is not None: - if isinstance(address_translator, type): - raise TypeError("address_translator should not be a class, it should be an instance of that class") - self.address_translator = address_translator - if application_info is not None: if not isinstance(application_info, ApplicationInfoBase): raise TypeError( @@ -1798,6 +1825,12 @@ def shutdown(self): if self.metrics_enabled and self.metrics: self.metrics.shutdown() + if self._client_routes_handler is not None: + try: + self._client_routes_handler.shutdown() + except Exception: + log.warning("Error shutting down client routes handler", exc_info=True) + _discard_cluster_shutdown(self) def __enter__(self): @@ -3612,11 +3645,24 @@ def _try_connect(self, endpoint): # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: - connection.register_watchers({ + watchers = { "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + } + + if self._cluster._client_routes_handler is not None: + watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change') + + connection.register_watchers(watchers, register_timeout=self._timeout) + + if self._cluster._client_routes_handler is not None: + try: + self._cluster._client_routes_handler.initialize(self) + except Exception as e: + log.error("[control connection] Failed to initialize client routes handler: %s", e, exc_info=True) + connection.close() + raise sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS @@ -3658,6 +3704,13 @@ def _reconnect(self): log.debug("[control connection] Attempting to reconnect") try: self._set_new_connection(self._reconnect_internal()) + + # Notify client routes handler of reconnection (full re-read) + if self._cluster._client_routes_handler is not None: + try: + self._cluster._client_routes_handler.handle_control_connection_reconnect(self) + except Exception as e: + log.warning("[control connection] Failed to notify client routes handler of reconnection: %s", e) except NoHostAvailable: # make a retry schedule (which includes backoff) schedule = self._cluster.reconnection_policy.new_schedule() @@ -3979,6 +4032,31 @@ def _handle_status_change(self, event): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event from the server. + + This event indicates that the system.client_routes table has been updated + and we need to refresh our route mappings. + """ + if self._cluster._client_routes_handler is None: + log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured") + return + + change_type = event.get("change_type") + connection_ids = event.get("connection_ids", []) + host_ids = event.get("host_ids", []) + + log.debug("[control connection] Received CLIENT_ROUTES_CHANGE: change_type=%s, " + "connection_ids=%s, host_ids=%s", change_type, connection_ids, host_ids) + + # Handle the event asynchronously + self._cluster.scheduler.schedule_unique( + 0, + self._cluster._client_routes_handler.handle_client_routes_change, + self, change_type, connection_ids, host_ids + ) + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..e4cf7ccf5a 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -25,12 +25,14 @@ from threading import Thread, Event, RLock, Condition import time import ssl +import uuid import weakref import random import itertools -from typing import Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from cassandra.application_info import ApplicationInfoBase +from cassandra.client_routes import ClientRoutesHandler from cassandra.protocol_features import ProtocolFeatures if 'gevent.monkey' in sys.modules: @@ -230,7 +232,7 @@ class DefaultEndPointFactory(EndPointFactory): port = None """ If no port is discovered in the row, this is the default port - used for endpoint creation. + used for endpoint creation. """ def __init__(self, port=None): @@ -328,6 +330,47 @@ def create_from_sni(self, sni): return SniEndPoint(self._proxy_address, sni, self._port) +class ClientRoutesEndPointFactory(EndPointFactory): + """ + EndPointFactory for Client Routes (Private Link) support. + + Creates ClientRoutesEndPoint instances that defer both address translation + (host_id -> hostname lookup) and DNS resolution until connection time. + This ensures immediate reaction to infrastructure changes. + """ + + def __init__(self, client_routes_handler: ClientRoutesHandler, port: Optional[int] = None) -> None: + """ + :param client_routes_handler: ClientRoutesHandler instance to lookup routes + :param port: Default port if none found in row + """ + self.client_routes_handler = client_routes_handler + self.default_port = port + + def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint': + """ + Create a ClientRoutesEndPoint from a system.peers row. + + Stores only the host_id and handler reference. Both translation + (route lookup) and DNS resolution happen later in resolve(). + """ + from cassandra.metadata import _NodeInfo + host_id = row.get("host_id") + + if host_id is None: + raise ValueError("No host_id to create ClientRoutesEndPoint") + + # Store original address for identification purposes + addr = _NodeInfo.get_broadcast_rpc_address(row) + + return ClientRoutesEndPoint( + host_id=host_id, + handler=self.client_routes_handler, + original_address=addr, + default_port=self.default_port + ) + + @total_ordering class UnixSocketEndPoint(EndPoint): """ @@ -369,6 +412,140 @@ def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) +@total_ordering +class ClientRoutesEndPoint(EndPoint): + """ + Client Routes (Private Link) EndPoint implementation. + + Defers both address translation (route lookup) and DNS resolution + until resolve() is called at connection time. This ensures immediate + reaction to infrastructure changes and CLIENT_ROUTES_CHANGE events. + """ + + def __init__(self, host_id: uuid.UUID, handler: ClientRoutesHandler, original_address: str, default_port: Optional[int] = None) -> None: + """ + :param host_id: Host UUID for route lookup + :param handler: ClientRoutesHandler instance + :param original_address: Original address from system.peers (for identification) + :param default_port: Default port if route doesn't specify one + """ + self._host_id = host_id + self._handler = handler + self._original_address = original_address + self._default_port = default_port + + @property + def address(self) -> str: + """Returns the original address for identification""" + return self._original_address + + @property + def port(self) -> int: + # Port is not known until resolve() - return a placeholder + return self._default_port if self._default_port else 9042 + + @property + def host_id(self) -> uuid.UUID: + return self._host_id + + def resolve(self) -> Tuple[str, int]: + """ + Perform both translation and DNS resolution at connection time. + + Steps: + 1. Translation: Look up current route by host_id to get DNS hostname + 2. Check cache: Use cached IPs if still valid (based on cache_ttl_seconds) + 3. DNS Resolution: If cache miss/expired, resolve hostname to IP and cache results + + This is called on every connection attempt, ensuring: + - Fresh route lookup (picks up CLIENT_ROUTES_CHANGE events) + - Cached DNS resolution (reduces load, configurable TTL) + """ + # Step 1: Translation - look up current route by host_id + route = self._handler.get_route_by_host_id(self._host_id) + + if route is None: + raise ConnectionException( + f"No client route found for host_id={self._host_id}", + endpoint=self + ) + + hostname = route.address + + port = route.port + if self._handler.ssl_enabled and route.tls_port: + port = route.tls_port + elif self._default_port: + port = self._default_port + + # Step 2: Check cache validity + cache_ttl = self._handler.config.cache_ttl_seconds + current_time = time.time() + cache_valid = ( + route.all_known_ips is not None and + len(route.all_known_ips) > 0 and + route.update_time is not None and + (current_time - route.update_time) < cache_ttl and + not route.forced_resolve + ) + + if cache_valid: + # Use cached IP (prefer current_ip if set, otherwise first from list) + resolved_ip = route.current_ip if route.current_ip else route.all_known_ips[0] + log.debug("ClientRoutesEndPoint using cached IP for host_id=%s -> %s (age: %.1fs)", + self._host_id, resolved_ip, current_time - route.update_time) + return resolved_ip, port + + # Step 3: DNS Resolution - resolve hostname to IP and cache results + try: + result = socket.getaddrinfo(hostname, port, + socket.AF_INET, socket.SOCK_STREAM) + if not result: + raise socket.gaierror("No addresses found for %s" % hostname) + + # Extract all resolved IPs + all_ips = [addr[4][0] for addr in result] + resolved_ip = all_ips[0] + + # Update route with resolved IPs and timestamp + def update_route(existing_route): + return existing_route._replace( + all_known_ips=all_ips, + current_ip=resolved_ip, + update_time=current_time, + forced_resolve=False + ) + + self._handler._routes.update_single(self._host_id, update_route) + + log.debug("ClientRoutesEndPoint resolved and cached host_id=%s -> %s -> %s (found %d IPs)", + self._host_id, hostname, resolved_ip, len(all_ips)) + return resolved_ip, port + except socket.gaierror as e: + log.warning('Could not resolve client routes hostname "%s" (host_id=%s): %s', + hostname, self._host_id, e) + raise + + def __eq__(self, other): + return (isinstance(other, ClientRoutesEndPoint) and + self._host_id == other._host_id and + self._original_address == other._original_address) + + def __hash__(self): + return hash((self._host_id, self._original_address)) + + def __lt__(self, other): + return ((str(self._host_id), self._original_address) < + (str(other._host_id), other._original_address)) + + def __str__(self): + return str("%s (host_id=%s)" % (self._original_address, self._host_id)) + + def __repr__(self): + return "<%s: host_id=%s, original_addr=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..92d10bb622 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1241,7 +1241,6 @@ def translate(self, addr): pass return addr - class SpeculativeExecutionPolicy(object): """ Interface for specifying speculative execution plans diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py new file mode 100644 index 0000000000..ce3bf26a3e --- /dev/null +++ b/tests/integration/standard/test_client_routes.py @@ -0,0 +1,185 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import uuid + +from cassandra.cluster import Cluster +from cassandra.client_routes import ClientRoutesConfig, ClientRoutesEndpoint, ClientRoutesHandler +from tests.integration import TestCluster, use_cluster + +def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes', [3], start=True) + +class TestGetHostPortMapping(unittest.TestCase): + """ + Test _query_routes method with different filtering scenarios. + """ + + @classmethod + def setUpClass(cls): + """Create test keyspace and table, populate with test data.""" + cls.cluster = TestCluster() + cls.session = cls.cluster.connect() + + cls.session.execute(""" + CREATE KEYSPACE IF NOT EXISTS gocql_test + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} + """) + + cls.session.execute(""" + CREATE TABLE IF NOT EXISTS gocql_test.client_routes ( + connection_id uuid, + host_id uuid, + address text, + port int, + tls_port int, + alternator_port int, + alternator_https_port int, + datacenter text, + rack text, + PRIMARY KEY (connection_id, host_id) + ) + """) + + cls.session.execute("TRUNCATE gocql_test.client_routes") + + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [uuid.uuid4() for _ in range(3)] + cls.racks = ["rack1", "rack2", "rack3"] + cls.expected = [] + + for idx, host_id in enumerate(cls.host_ids): + rack = cls.racks[idx] + ip = f"127.0.0.{idx + 1}" + + for connection_id in cls.connection_ids: + cls.session.execute( + """ + INSERT INTO gocql_test.client_routes + (connection_id, host_id, address, port, tls_port, + alternator_port, alternator_https_port, datacenter, rack) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + (connection_id, host_id, ip, 9042, 9142, 0, 0, 'dc1', rack) + ) + + cls.expected.append({ + 'connection_id': connection_id, + 'host_id': host_id, + 'address': ip, + 'port': 9042, + 'tls_port': 9142, + 'datacenter': 'dc1', + 'rack': rack + }) + + cls._sort_routes(cls.expected) + + @classmethod + def tearDownClass(cls): + """Clean up test keyspace.""" + try: + cls.session.execute("DROP KEYSPACE IF EXISTS gocql_test") + finally: + cls.cluster.shutdown() + + @staticmethod + def _sort_routes(routes): + """Sort routes by connection_id then host_id for deterministic comparison.""" + routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + + def _query_and_compare(self, connection_ids, host_ids, expected): + """ + Query routes using ClientRoutesHandler._query_routes and compare with expected. + + :param connection_ids: List of connection UUIDs or None + :param host_ids: List of host UUIDs or None + :param expected: Expected list of route dicts + """ + config = ClientRoutesConfig( + endpoints=[ClientRoutesEndpoint( + connection_id=str(self.connection_ids[0]), + connection_addr="127.0.0.1" + )], + table_name="gocql_test.client_routes" + ) + handler = ClientRoutesHandler(config) + + routes = handler._query_routes( + self.cluster.control_connection, + connection_ids=connection_ids, + host_ids=host_ids + ) + + got = [] + for route in routes: + got.append({ + 'connection_id': route.connection_id, + 'host_id': route.host_id, + 'address': route.address, + 'port': route.port, + 'tls_port': route.tls_port, + 'datacenter': route.datacenter, + 'rack': route.rack + }) + + self._sort_routes(got) + + self.assertEqual(len(got), len(expected), + f"Expected {len(expected)} routes, got {len(got)}") + + for i, (got_route, expected_route) in enumerate(zip(got, expected)): + self.assertEqual(got_route['connection_id'], expected_route['connection_id'], + f"Route {i}: connection_id mismatch") + self.assertEqual(got_route['host_id'], expected_route['host_id'], + f"Route {i}: host_id mismatch") + self.assertEqual(got_route['address'], expected_route['address'], + f"Route {i}: address mismatch") + self.assertEqual(got_route['port'], expected_route['port'], + f"Route {i}: port mismatch") + self.assertEqual(got_route['tls_port'], expected_route['tls_port'], + f"Route {i}: tls_port mismatch") + + def test_get_all(self): + """Test querying all routes without filters.""" + self._query_and_compare(None, None, self.expected) + + def test_get_all_hosts(self): + """Test querying with connection_ids filter only.""" + self._query_and_compare(self.connection_ids, None, self.expected) + + def test_get_all_connections(self): + """Test querying with host_ids filter only.""" + self._query_and_compare(None, self.host_ids, self.expected) + + def test_get_concrete(self): + """Test querying with both connection_ids and host_ids filters.""" + self._query_and_compare(self.connection_ids, self.host_ids, self.expected) + + def test_get_concrete_host(self): + """Test querying specific connection and host combination.""" + filtered_expected = [ + r for r in self.expected + if r['connection_id'] == self.connection_ids[0] and + r['host_id'] == self.host_ids[0] + ] + + self._query_and_compare( + [self.connection_ids[0]], + [self.host_ids[0]], + filtered_expected + ) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py new file mode 100644 index 0000000000..5b1674567f --- /dev/null +++ b/tests/unit/test_client_routes.py @@ -0,0 +1,284 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import uuid +import time +from unittest.mock import Mock, patch + +from cassandra.client_routes import ( + ClientRoutesEndpoint, + ClientRoutesConfig, + ResolvedRoutes, + ResolvedRoute, + ClientRoutesHandler +) +from cassandra import DriverException + + +class TestClientRoutesEndpoint(unittest.TestCase): + + def test_endpoint_none_connection_id(self): + with self.assertRaises(ValueError): + ClientRoutesEndpoint(None) + + +class TestClientRoutesConfig(unittest.TestCase): + + def test_config_with_endpoints(self): + ep1 = ClientRoutesEndpoint(str(uuid.uuid4()), "10.0.0.1") + ep2 = ClientRoutesEndpoint(str(uuid.uuid4()), "10.0.0.2") + config = ClientRoutesConfig([ep1, ep2]) + self.assertEqual(len(config.endpoints), 2) + self.assertEqual(config.table_name, "system.client_routes") + + def test_config_custom_table_name(self): + ep = ClientRoutesEndpoint(str(uuid.uuid4())) + config = ClientRoutesConfig([ep], table_name="custom.routes") + self.assertEqual(config.table_name, "custom.routes") + + def test_config_empty_endpoints(self): + with self.assertRaises(ValueError): + ClientRoutesConfig([]) + + def test_config_invalid_endpoint_type(self): + with self.assertRaises(TypeError): + ClientRoutesConfig(["not-an-endpoint"]) + + def test_config_cache_ttl_validation(self): + ep = ClientRoutesEndpoint(str(uuid.uuid4())) + # Must be >= 0 + with self.assertRaises(ValueError): + ClientRoutesConfig([ep], cache_ttl_seconds=-1) + # 0 is valid (never cache) + config = ClientRoutesConfig([ep], cache_ttl_seconds=0) + self.assertEqual(config.cache_ttl_seconds, 0) + + +class TestResolvedRoutes(unittest.TestCase): + + def test_get_by_host_id(self): + routes = ResolvedRoutes() + host_id = uuid.uuid4() + route = ResolvedRoute( + connection_id=uuid.uuid4(), + host_id=host_id, + address="example.com", + port=9042, + tls_port=9142, + datacenter="dc1", + rack="rack1", + all_known_ips=["192.168.1.1", "192.168.1.2"], + current_ip="192.168.1.1", + update_time=12345.0, + forced_resolve=False + ) + + routes.update([route]) + + retrieved = routes.get_by_host_id(host_id) + self.assertEqual(retrieved.host_id, host_id) + self.assertEqual(retrieved.address, "example.com") + + def test_merge_routes(self): + routes = ResolvedRoutes() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + + route1 = ResolvedRoute( + connection_id=uuid.uuid4(), host_id=host_id1, + address="host1.com", port=9042, tls_port=None, + datacenter="dc1", rack="rack1", + all_known_ips=["192.168.1.1"], current_ip="192.168.1.1", + update_time=12345.0, forced_resolve=False + ) + + route2 = ResolvedRoute( + connection_id=uuid.uuid4(), host_id=host_id2, + address="host2.com", port=9042, tls_port=None, + datacenter="dc1", rack="rack1", + all_known_ips=["192.168.1.2"], current_ip="192.168.1.2", + update_time=12345.0, forced_resolve=False + ) + + routes.update([route1]) + routes.merge([route2]) + + self.assertIsNotNone(routes.get_by_host_id(host_id1)) + self.assertIsNotNone(routes.get_by_host_id(host_id2)) + + +class TestClientRoutesHandler(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.endpoint = ClientRoutesEndpoint(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.endpoint]) + + def test_handler_initialization(self): + handler = ClientRoutesHandler(self.config, ssl_enabled=False) + self.assertIsNotNone(handler) + self.assertEqual(handler.ssl_enabled, False) + + @patch.object(ClientRoutesHandler, '_query_routes') + def test_initialize(self, mock_query): + host_id = uuid.uuid4() + mock_query.return_value = [ + ResolvedRoute( + connection_id=self.conn_id, + host_id=host_id, + address="node1.example.com", + port=9042, + tls_port=9142, + datacenter="dc1", + rack="rack1", + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True + ) + ] + + handler = ClientRoutesHandler(self.config) + mock_control_conn = Mock() + + handler.initialize(mock_control_conn) + + mock_query.assert_called_once() + # Verify route was stored + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "node1.example.com") + +class TestQueryBuilding(unittest.TestCase): + + def test_query_all_routes(self): + """Test querying all routes without filters.""" + conn_id = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(str(conn_id), "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + # Mock response + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query all routes + handler._query_routes(mock_control_conn, connection_ids=None, host_ids=None) + + # Verify query structure + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("select * from", query_str) + self.assertIn("allow filtering", query_str) + self.assertNotIn("where", query_str) + + def test_query_with_connection_ids_only(self): + """Test querying with connection_ids filter only.""" + conn_id1 = uuid.uuid4() + conn_id2 = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(str(conn_id1), "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with connection_ids filter + handler._query_routes(mock_control_conn, connection_ids=[conn_id1, conn_id2], host_ids=None) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + self.assertIn("connection_id in", query_str) + self.assertNotIn("host_id in", query_str) + + def test_query_with_host_ids_only(self): + """Test querying with host_ids filter only.""" + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(str(conn_id), "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with host_ids filter + handler._query_routes(mock_control_conn, connection_ids=None, host_ids=[host_id]) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + self.assertIn("host_id in", query_str) + self.assertNotIn("connection_id in", query_str) + + def test_query_with_both_filters(self): + """Test querying with both connection_ids and host_ids filters.""" + conn_id = uuid.uuid4() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(str(conn_id), "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with both filters + handler._query_routes(mock_control_conn, connection_ids=[conn_id], host_ids=[host_id1, host_id2]) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + self.assertIn("connection_id in", query_str) + self.assertIn("host_id in", query_str) + # When both filters present, should not use ALLOW FILTERING + self.assertNotIn("allow filtering", query_str) + + +if __name__ == '__main__': + unittest.main()