Skip to content

Commit 75cbd34

Browse files
committed
Add noop async client
1 parent 68f567a commit 75cbd34

5 files changed

Lines changed: 315 additions & 0 deletions

File tree

requirements/test.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pytest
22
pytest-cov
33
pytest-xdist
44
coverage
5+
pytest-asyncio

requirements/test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ pluggy==1.5.0
1717
pytest==8.2.2
1818
# via
1919
# -r requirements/test.in
20+
# pytest-asyncio
2021
# pytest-cov
2122
# pytest-xdist
23+
pytest-asyncio==0.23.7
24+
# via -r requirements/test.in
2225
pytest-cov==5.0.0
2326
# via -r requirements/test.in
2427
pytest-xdist==3.6.1

src/statsd/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Version: v\ |version|.
55
"""
66

7+
from .async_client import BaseAsyncStatsdClient, DebugAsyncStatsdClient
78
from .base import Sample
89
from .client import (
910
BaseStatsdClient,
@@ -15,7 +16,9 @@
1516

1617

1718
__all__ = (
19+
"BaseAsyncStatsdClient",
1820
"BaseStatsdClient",
21+
"DebugAsyncStatsdClient",
1922
"DebugStatsdClient",
2023
"Sample",
2124
"StatsdClient",

src/statsd/async_client.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import contextlib
5+
import functools
6+
import logging
7+
import time
8+
from typing import Any, AsyncIterator, Awaitable, Callable, Mapping, TypeVar
9+
from typing_extensions import ParamSpec
10+
11+
from statsd.base import AbstractStatsdClient
12+
13+
14+
P = ParamSpec("P")
15+
T = TypeVar("T")
16+
U = TypeVar("U")
17+
18+
logger = logging.getLogger("statsd")
19+
20+
21+
class BaseAsyncStatsdClient(AbstractStatsdClient[Awaitable[None]]):
22+
"""
23+
Base async client.
24+
25+
This class exposes the public interface and takes care of packet formatting
26+
as well as sampling. It does not actually send packets anywhere, which is
27+
left to concrete subclasses implementing :meth:`_emit`.
28+
"""
29+
30+
@abc.abstractmethod
31+
async def _emit(self, packets: list[str]) -> None:
32+
"""
33+
Async send implementation.
34+
35+
This method is responsible for actually sending the formatted packets
36+
and should be implemented by all subclasses.
37+
38+
It may batch or buffer packets but should not modify them in any way. It
39+
should be agnostic to the Statsd format.
40+
"""
41+
raise NotImplementedError()
42+
43+
def timed(
44+
self,
45+
name: str | None = None,
46+
*,
47+
tags: Mapping[str, str] | None = None,
48+
sample_rate: float | None = None,
49+
use_distribution: bool = False,
50+
) -> Callable[[Callable[P, Awaitable[U]]], Callable[P, Awaitable[U]]]:
51+
"""
52+
Wrap a function to record its execution time.
53+
54+
This just wraps the function call with a :meth:`timer` context manager.
55+
56+
If a name is not provided, the function name will be used.
57+
58+
Passing ``use_distribution=True`` will report the value as a globally
59+
aggregated :meth:`distribution` metric instead of a :meth:`timing`
60+
metric.
61+
62+
>>> client = AsyncStatsdClient()
63+
>>> @client.timed()
64+
... async def do_something():
65+
... pass
66+
"""
67+
68+
def decorator(
69+
fn: Callable[P, Awaitable[U]],
70+
) -> Callable[P, Awaitable[U]]:
71+
# TODO: Should the fallback include the module? Class (for methods)?
72+
# or func.__name__
73+
metric_name = name or fn.__name__
74+
75+
@functools.wraps(fn)
76+
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> U:
77+
async with self.timer(
78+
metric_name,
79+
tags=tags,
80+
use_distribution=use_distribution,
81+
sample_rate=sample_rate,
82+
):
83+
return await fn(*args, **kwargs)
84+
85+
return wrapped
86+
87+
return decorator
88+
89+
@contextlib.asynccontextmanager
90+
async def timer(
91+
self,
92+
name: str,
93+
*,
94+
tags: Mapping[str, str] | None = None,
95+
sample_rate: float | None = None,
96+
use_distribution: bool = False,
97+
) -> AsyncIterator[None]:
98+
"""
99+
Context manager to measure the execution time of an async block.
100+
101+
Passing ``use_distribution=True`` will report the value as a globally
102+
aggregated :meth:`distribution` metric instead of a :meth:`timing`
103+
metric.
104+
105+
>>> client = AsyncStatsdClient()
106+
>>> async def operation():
107+
... async with client.timer("download_duration"):
108+
... pass
109+
"""
110+
start = time.perf_counter()
111+
try:
112+
yield
113+
finally:
114+
duration_ms = int(1000 * (time.perf_counter() - start))
115+
if use_distribution:
116+
await self.distribution(
117+
name,
118+
duration_ms,
119+
tags=tags,
120+
sample_rate=sample_rate,
121+
)
122+
else:
123+
await self.timing(
124+
name,
125+
duration_ms,
126+
tags=tags,
127+
sample_rate=sample_rate,
128+
)
129+
130+
131+
class DebugAsyncStatsdClient(BaseAsyncStatsdClient):
132+
"""
133+
Verbose client for development or debugging purposes.
134+
135+
All Statsd packets will be logged and optionally forwarded to a wrapped
136+
client.
137+
"""
138+
139+
def __init__(
140+
self,
141+
level: int = logging.INFO,
142+
logger: logging.Logger = logger,
143+
inner: BaseAsyncStatsdClient | None = None,
144+
**kwargs: Any,
145+
) -> None:
146+
r"""
147+
Initialize DebugStatsdClient.
148+
149+
:param level: Log level to use, defaults to ``INFO``.
150+
151+
:param logger: Logger instance to use, defaults to ``statsd``.
152+
153+
:param inner: Wrapped client.
154+
155+
:param \**kwargs: Extra arguments forwarded to :class:`BaseAsyncStatsdClient`.
156+
"""
157+
super().__init__(**kwargs)
158+
self.level = level
159+
self.logger = logger
160+
self.inner = inner
161+
162+
async def _emit(self, packets: list[str]) -> None:
163+
for packet in packets:
164+
self.logger.log(self.level, "> %s", packet)
165+
if self.inner:
166+
await self.inner._emit(packets)
167+
168+
169+
AsyncStatsdClient = DebugAsyncStatsdClient

tests/test_base_async_client.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import Any
5+
from unittest import mock
6+
7+
import pytest
8+
9+
from statsd import BaseAsyncStatsdClient, DebugAsyncStatsdClient
10+
11+
12+
class MockClient(BaseAsyncStatsdClient):
13+
def __init__(self, *args: Any, **kwargs: Any) -> None:
14+
super().__init__(*args, **kwargs)
15+
self.mock = mock.Mock()
16+
17+
async def _emit(self, packets: list[str]) -> None:
18+
if packets:
19+
self.mock(packets)
20+
21+
def assert_emitted(self, expected: list[str] | str) -> None:
22+
self.mock.assert_called_once_with(
23+
expected if isinstance(expected, list) else [expected],
24+
)
25+
26+
def assert_did_not_emit(self) -> None:
27+
self.mock.assert_not_called()
28+
29+
30+
@pytest.mark.asyncio()
31+
async def test_timed_decorator() -> None:
32+
client = MockClient()
33+
34+
@client.timed("foo", tags={"foo": "1"})
35+
async def fn() -> None:
36+
pass
37+
38+
with mock.patch(
39+
"time.perf_counter",
40+
side_effect=[7.886838544, 20.181117592],
41+
):
42+
await fn()
43+
44+
client.mock.assert_called_once_with(["foo:12294|ms|#foo:1"])
45+
46+
47+
@pytest.mark.asyncio()
48+
async def test_timed_decorator_use_distribution() -> None:
49+
client = MockClient()
50+
51+
@client.timed("foo", tags={"foo": "1"}, use_distribution=True)
52+
async def fn() -> None:
53+
pass
54+
55+
with mock.patch(
56+
"time.perf_counter",
57+
side_effect=[7.886838544, 20.181117592],
58+
):
59+
await fn()
60+
61+
client.mock.assert_called_once_with(["foo:12294|d|#foo:1"])
62+
63+
64+
SIMPLE_TEST_CASES: list[
65+
tuple[str, tuple[Any, ...], dict[str, Any], list[str] | str]
66+
] = [
67+
("increment", ("foo",), {}, "foo:1|c"),
68+
("increment", ("foo", 10), {}, "foo:10|c"),
69+
]
70+
71+
72+
@pytest.mark.parametrize(
73+
("method", "args", "kwargs", "expected"),
74+
SIMPLE_TEST_CASES,
75+
)
76+
@pytest.mark.asyncio()
77+
async def test_debug_client_no_inner(
78+
method: str,
79+
args: tuple[Any, ...],
80+
kwargs: dict[str, Any],
81+
expected: list[str] | str,
82+
caplog: Any,
83+
) -> None:
84+
client = DebugAsyncStatsdClient()
85+
with caplog.at_level(logging.INFO, logger="statsd"):
86+
await getattr(client, method)(*args, **kwargs)
87+
88+
if isinstance(expected, list):
89+
assert len(caplog.records) == len(expected)
90+
for x in expected:
91+
assert x in caplog.text
92+
else:
93+
assert len(caplog.records) == 1
94+
assert expected in caplog.text
95+
96+
97+
@pytest.mark.parametrize(
98+
("method", "args", "kwargs", "expected"),
99+
SIMPLE_TEST_CASES,
100+
)
101+
@pytest.mark.asyncio()
102+
async def test_debug_client(
103+
method: str,
104+
args: tuple[Any, ...],
105+
kwargs: dict[str, Any],
106+
expected: str | list[str],
107+
) -> None:
108+
inner = MockClient()
109+
client = DebugAsyncStatsdClient(inner=inner)
110+
await getattr(client, method)(*args, **kwargs)
111+
inner.assert_emitted(expected)
112+
113+
114+
@pytest.mark.parametrize(
115+
("method", "args", "kwargs", "expected"),
116+
SIMPLE_TEST_CASES,
117+
)
118+
@pytest.mark.asyncio()
119+
async def test_debug_client_custom_logger_and_level(
120+
method: str,
121+
args: tuple[Any, ...],
122+
kwargs: dict[str, Any],
123+
expected: list[str] | str,
124+
caplog: Any,
125+
) -> None:
126+
client = DebugAsyncStatsdClient(
127+
logger=logging.getLogger("foo"),
128+
level=logging.DEBUG,
129+
)
130+
with caplog.at_level(logging.DEBUG, logger="foo"):
131+
await getattr(client, method)(*args, **kwargs)
132+
133+
if isinstance(expected, list):
134+
assert len(caplog.records) == len(expected)
135+
for x in expected:
136+
assert x in caplog.text
137+
else:
138+
assert len(caplog.records) == 1
139+
assert expected in caplog.text

0 commit comments

Comments
 (0)