-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathstale_dns_plugin.py
More file actions
211 lines (163 loc) · 8.39 KB
/
stale_dns_plugin.py
File metadata and controls
211 lines (163 loc) · 8.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import socket
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set
if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.properties import Properties
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.hostinfo import HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.notifications import HostEvent
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils
from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils
# Module-level logger used by the classes in this file.
logger = Logger(__name__)
class StaleDnsHelper:
    """
    Guards against stale DNS on writer cluster endpoints.

    A connection opened through a writer cluster endpoint may be stale. This is
    detected when the endpoint resolves to an address other than the topology
    writer's address, or when the connection lands on a reader. In that case the
    connection is replaced with a direct connection to the writer instance taken
    from the topology.
    """

    # NOTE(review): not referenced by this class.
    RETRIES: int = 3

    def __init__(self, plugin_service: PluginService) -> None:
        self._plugin_service = plugin_service
        self._rds_helper = RdsUtils()
        # Writer host taken from the topology, and the IP address its host name resolved to.
        # Both are cached across calls and cleared together by notify_host_list_changed.
        self._writer_host_info: Optional[HostInfo] = None
        self._writer_host_address: Optional[str] = None

    def get_verified_connection(self, is_initial_connection: bool, host_list_provider_service: HostListProviderService, host_info: HostInfo,
                                props: Properties, connect_func: Callable) -> Connection:
        """
        Ensure a connection opened through a writer cluster endpoint is not a stale
        connection to an instance that is no longer the writer.

        :param is_initial_connection: if True and the connection is redirected to the
            writer, the writer is recorded as the provider's initial connection host info.
        :param host_list_provider_service: receives the writer as
            ``initial_connection_host_info`` when an initial connection is redirected.
        :param host_info: the host being connected to.
        :param props: connection properties, reused when connecting directly to the writer.
        :param connect_func: opens the connection to ``host_info``.
        :return: the connection from ``connect_func``. If stale DNS is detected, a new
            connection to the current writer is returned instead and the original
            connection is closed.
        :raises AwsWrapperError: if stale DNS is detected but the current writer is not
            among the plugin service's allowed hosts.
        """
        host = host_info.host
        if (not self._rds_helper.is_writer_cluster_dns(host)
                and not self._rds_helper.is_global_db_writer_cluster_dns(host)):
            # Not a writer cluster endpoint: nothing to verify.
            return connect_func()

        conn: Connection = connect_func()
        try:
            return self._verify_connection(conn, is_initial_connection, host_list_provider_service, host_info, props)
        except BaseException:
            # The caller never receives ``conn`` on this path, so close it instead of leaking it.
            self._close_quietly(conn)
            raise

    def _verify_connection(self, conn: Connection, is_initial_connection: bool,
                           host_list_provider_service: HostListProviderService, host_info: HostInfo,
                           props: Properties) -> Connection:
        """Run the stale-DNS checks on ``conn``; return it, or a replacement writer connection."""
        cluster_inet_address: Optional[str] = self._resolve(host_info.host)
        logger.debug("StaleDnsHelper.ClusterEndpointDns", host_info.host, cluster_inet_address)
        if cluster_inet_address is None:
            return conn

        connected_to_reader = self._plugin_service.get_host_role(conn) == HostRole.READER
        if connected_to_reader:
            # This branch is only reached if the connection url is a writer cluster endpoint.
            # If the new connection resolves to a reader instance, the topology is outdated.
            # Force refresh to update the topology.
            self._plugin_service.force_refresh_host_list(conn)
        else:
            self._plugin_service.refresh_host_list(conn)

        logger.debug("LogUtils.Topology", LogUtils.log_topology(self._plugin_service.all_hosts))

        if self._writer_host_info is None:
            writer_candidate: Optional[HostInfo] = self._get_writer()
            if writer_candidate is not None and self._rds_helper.is_rds_cluster_dns(writer_candidate.host):
                # The topology names a cluster endpoint rather than an instance as writer;
                # there is nothing to compare against.
                return conn
            self._writer_host_info = writer_candidate

        logger.debug("StaleDnsHelper.WriterHostSpec", self._writer_host_info)

        writer = self._writer_host_info
        if writer is None:
            return conn

        if self._writer_host_address is None:
            self._writer_host_address = self._resolve(writer.host)

        logger.debug("StaleDnsHelper.WriterInetAddress", self._writer_host_address)

        if self._writer_host_address is None:
            return conn

        if self._writer_host_address == cluster_inet_address and not connected_to_reader:
            # The cluster endpoint points at the current writer: the connection is fine.
            return conn

        logger.debug("StaleDnsHelper.StaleDnsDetected", writer)

        allowed_hosts = self._plugin_service.hosts
        if not Utils.contains_host_and_port(tuple(allowed_hosts), writer.get_host_and_port()):
            raise AwsWrapperError(
                Messages.get_formatted(
                    "StaleDnsHelper.CurrentWriterNotAllowed",
                    writer.get_host_and_port(),
                    LogUtils.log_topology(allowed_hosts)))

        writer_conn: Connection = self._plugin_service.connect(writer, props)
        if is_initial_connection:
            host_list_provider_service.initial_connection_host_info = writer

        # The stale connection is replaced by ``writer_conn``.
        self._close_quietly(conn)
        return writer_conn

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]) -> None:
        """
        Clear the cached writer if the host-list changes report it was converted to a reader.

        :param changes: host events keyed by host url.
        """
        if self._writer_host_info is None:
            return

        writer_changes = changes.get(self._writer_host_info.url, None)
        if writer_changes is not None and HostEvent.CONVERTED_TO_READER in writer_changes:
            logger.debug("StaleDnsHelper.Reset")
            self._writer_host_info = None
            self._writer_host_address = None

    def _get_writer(self) -> Optional[HostInfo]:
        """Return the first host in the full topology whose role is WRITER, or None."""
        for host in self._plugin_service.all_hosts:
            if host.role == HostRole.WRITER:
                return host
        return None

    @staticmethod
    def _resolve(host: str) -> Optional[str]:
        """Resolve ``host`` with socket.gethostbyname; return None if resolution fails."""
        try:
            return socket.gethostbyname(host)
        except socket.gaierror:
            return None

    @staticmethod
    def _close_quietly(conn: Optional[Connection]) -> None:
        """Close ``conn`` if it is not None, ignoring any error raised while closing."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            # The connection is being discarded either way.
            pass
class StaleDnsPlugin(Plugin):
    """
    Plugin that verifies connections made through writer cluster endpoints with
    StaleDnsHelper. It also refreshes the host list before executing subscribed
    network-bound methods.
    """

    # Methods this plugin always subscribes to. This class-level set is never mutated;
    # the plugin service's network-bound methods are added per instance in __init__.
    _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.INIT_HOST_PROVIDER.method_name,
                                     DbApiMethod.CONNECT.method_name,
                                     DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name}

    # Assigned by init_host_provider and passed to the helper by connect.
    _host_list_provider_service: HostListProviderService

    def __init__(self, plugin_service: PluginService) -> None:
        self._plugin_service = plugin_service
        self._stale_dns_helper = StaleDnsHelper(self._plugin_service)
        # Build a per-instance set. Updating the class-level set would make every instance
        # share one subscription set that accumulates every plugin service's methods.
        self._subscribed_methods: Set[str] = StaleDnsPlugin._SUBSCRIBED_METHODS.union(
            self._plugin_service.network_bound_methods)

    @property
    def subscribed_methods(self) -> Set[str]:
        """Method names this plugin intercepts."""
        return self._subscribed_methods

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """
        Open a connection with ``connect_func`` and verify it with StaleDnsHelper.

        If stale DNS is detected, the helper may return a connection to the current writer.
        """
        return self._stale_dns_helper.get_verified_connection(
            is_initial_connection, self._host_list_provider_service, host_info, props, connect_func)

    def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Refresh the host list on a best-effort basis, then run ``execute_func``."""
        try:
            self._plugin_service.refresh_host_list()
        except Exception:
            # Refresh errors are swallowed so a failed refresh does not block the intercepted call.
            pass
        return execute_func()

    def init_host_provider(
            self,
            properties: Properties,
            host_list_provider_service: HostListProviderService,
            init_host_provider_func: Callable) -> None:
        """
        Store the host list provider service and run the rest of the initialization chain.

        :raises AwsWrapperError: if the resulting host list provider is static, because
            stale DNS detection needs a dynamic provider.
        """
        self._host_list_provider_service = host_list_provider_service
        init_host_provider_func()
        if self._host_list_provider_service.is_static_host_list_provider():
            raise AwsWrapperError(Messages.get_formatted("StaleDnsPlugin.RequireDynamicProvider"))

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]) -> None:
        """Forward host-list change events, keyed by host url, to the helper."""
        self._stale_dns_helper.notify_host_list_changed(changes)
class StaleDnsPluginFactory(PluginFactory):
    """Factory that creates StaleDnsPlugin instances."""

    @staticmethod
    def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
        # ``props`` belongs to the factory interface; this plugin needs only the service.
        plugin: Plugin = StaleDnsPlugin(plugin_service)
        return plugin