Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,26 @@ This is very inefficient and should not be used for high-volume schedules. Because
This source holds values in lists.

* For cron tasks it uses key `{prefix}:cron`.
* For interval tasks it uses key `{prefix}:interval`.
* For timed schedules it uses key `{prefix}:time:{time}` where `{time}` is the actual time at which the schedules should run.
* A sorted set at `{prefix}:time_index` tracks all time keys with their unix timestamps as scores, so that past time schedules can be discovered via `ZRANGEBYSCORE` instead of scanning all Redis keys. Stale entries (older than 5 minutes with empty time key lists) are cleaned up automatically.

The main advantage of this approach is that we only fetch tasks we need to run at a given time and do not perform any excesive calls to redis.
The main advantage of this approach is that we only fetch tasks we need to run at a given time and do not perform any excessive calls to redis.

#### `populate_time_index`

If you are upgrading from an older version that did not maintain the `{prefix}:time_index` sorted set, existing time keys will not be present in the index. Set `populate_time_index=True` once on startup to backfill the index via a one-time `SCAN`, then set it back to `False` for subsequent runs:

```python
# First run after upgrading — backfills the time index
source = ListRedisScheduleSource(
"redis://localhost/1",
populate_time_index=True,
)

# All subsequent runs — no SCAN, uses the time index
source = ListRedisScheduleSource("redis://localhost/1")
```


### Migration from one source to another
Expand Down
142 changes: 116 additions & 26 deletions taskiq_redis/list_schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import time as _time
from logging import getLogger
from typing import Any

Expand All @@ -23,6 +24,7 @@ def __init__(
serializer: TaskiqSerializer | None = None,
buffer_size: int = 50,
skip_past_schedules: bool = False,
populate_time_index: bool = False,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -34,6 +36,11 @@ def __init__(
:param serializer: Serializer to use for the schedules
:param buffer_size: Buffer size for getting schedules
:param skip_past_schedules: Skip schedules that are in the past.
:param populate_time_index: If True, on startup run a one-time SCAN
to populate the time index sorted set from existing time keys.
This is needed for migrating from an older version that did not
maintain the time index. Set this to True once to backfill the
index, then set it back to False for subsequent runs.
:param connection_kwargs: Additional connection kwargs
"""
super().__init__()
Expand All @@ -47,10 +54,11 @@ def __init__(
if serializer is None:
serializer = PickleSerializer()
self._serializer = serializer
self._is_first_run = True
self._previous_schedule_source: ScheduleSource | None = None
self._delete_schedules_after_migration: bool = True
self._skip_past_schedules = skip_past_schedules
self._populate_time_index = populate_time_index
self._last_cleanup_time: float = 0

async def startup(self) -> None:
"""
Expand All @@ -59,6 +67,9 @@ async def startup(self) -> None:
By default this function does nothing.
But if the previous schedule source is set,
it will try to migrate schedules from it.

If populate_time_index is True, it will scan for existing
time keys and populate the time index sorted set.
"""
if self._previous_schedule_source is not None:
logger.info("Migrating schedules from previous source")
Expand All @@ -74,13 +85,36 @@ async def startup(self) -> None:
await self._previous_schedule_source.shutdown()
logger.info("Migration complete")

if self._populate_time_index:
logger.info("Populating time index from existing keys via scan")
async with Redis(connection_pool=self._connection_pool) as redis:
batch: dict[str, float] = {}
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
key_str = key.decode()
key_time = self._parse_time_key(key_str)
if key_time:
batch[key_str] = key_time.timestamp()
if len(batch) >= self._buffer_size:
await redis.zadd(
self._get_time_index_key(),
batch,
)
batch = {}
if batch:
await redis.zadd(self._get_time_index_key(), batch)
logger.info("Time index population complete")

def _get_time_key(self, time: datetime.datetime) -> str:
"""Get the key for a time-based schedule."""
if time.tzinfo is None:
time = time.replace(tzinfo=datetime.timezone.utc)
iso_time = time.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M")
return f"{self._prefix}:time:{iso_time}"

def _get_time_index_key(self) -> str:
"""Get the key for the time index sorted set."""
return f"{self._prefix}:time_index"

def _get_cron_key(self) -> str:
"""Get the key for a cron-based schedule."""
return f"{self._prefix}:cron"
Expand All @@ -103,35 +137,78 @@ def _parse_time_key(self, key: str) -> datetime.datetime | None:
logger.debug("Failed to parse time key %s", key)
return None

async def _get_previous_time_schedules(self) -> list[bytes]:
async def _maybe_cleanup_time_index(self, redis: Redis) -> None:  # type: ignore[type-arg]
    """
    Throttle time-index cleanup to at most one run per minute.

    Invoked from delete_schedule after removing a time-based schedule,
    because that is the code path on which time key lists become empty.
    """
    current = _time.monotonic()
    if current - self._last_cleanup_time >= 60:
        # Stamp the time first so concurrent callers within the same
        # minute bail out on the guard above.
        self._last_cleanup_time = current
        await self._cleanup_time_index(redis)

async def _cleanup_time_index(self, redis: Redis) -> None:  # type: ignore[type-arg]
    """
    Remove stale entries from the time index sorted set.

    An index entry is stale when its score is older than 5 minutes AND
    the corresponding time key list is empty (or no longer exists).
    The 5-minute grace period avoids a race condition where an eager
    cleanup in delete_schedule could remove an index entry right as
    add_schedule is creating a new schedule at the same minute.

    :param redis: An open Redis client to issue commands on.
    """
    index_key = self._get_time_index_key()
    five_minutes_ago = (
        datetime.datetime.now(datetime.timezone.utc)
        - datetime.timedelta(minutes=5)
    ).timestamp()
    stale_keys: list[bytes] = await redis.zrangebyscore(
        index_key,
        "-inf",
        five_minutes_ago,
    )
    # Collect every removable key first so we issue a single variadic
    # ZREM instead of one round trip per stale entry.
    removable = [key for key in stale_keys if await redis.llen(key) == 0]
    if removable:
        await redis.zrem(index_key, *removable)

async def _get_previous_time_schedules(
self,
current_time: datetime.datetime,
) -> list[bytes]:
"""
Function that gets all timed schedules that are in the past.

Since this source doesn't retrieve all the schedules at once,
we need to get all the schedules that are in the past and haven't
been sent yet.

We do this by getting all the time keys and checking if the time
is less than the current time.
Uses the time index sorted set to look up past time keys
instead of scanning all Redis keys.

Called on every get_schedules invocation so that schedules
added in a past minute (after the previous get_schedules call
but before the minute rolled over) are never missed.

This function is called only during the first run to minimize
the number of requests to the Redis server.
:param current_time: The reference time captured by the caller,
used to derive the cutoff so that the "previous" and "current"
windows never overlap.
"""
logger.info("Getting previous time schedules")
minute_before = datetime.datetime.now(
datetime.timezone.utc,
).replace(second=0, microsecond=0) - datetime.timedelta(
minute_before = current_time.replace(
second=0, microsecond=0,
) - datetime.timedelta(
minutes=1,
)
schedules = []
async with Redis(connection_pool=self._connection_pool) as redis:
time_keys: list[str] = []
# We need to get all the time keys and check if the time is less than
# the current time.
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
key_time = self._parse_time_key(key.decode())
if key_time and key_time <= minute_before:
time_keys.append(key.decode())
max_score = minute_before.timestamp()
time_keys: list[bytes] = await redis.zrangebyscore(
self._get_time_index_key(),
"-inf",
max_score,
)
for key in time_keys:
schedules.extend(await redis.lrange(key, 0, -1)) # type: ignore[misc]

Expand All @@ -153,6 +230,7 @@ async def delete_schedule(self, schedule_id: str) -> None:
elif schedule.time is not None:
time_key = self._get_time_key(schedule.time)
await redis.lrem(time_key, 0, schedule_id) # type: ignore[misc]
await self._maybe_cleanup_time_index(redis)
elif schedule.interval:
await redis.lrem(self._get_interval_key(), 0, schedule_id) # type: ignore[misc]

Expand All @@ -170,9 +248,21 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None:
if schedule.cron is not None:
await redis.rpush(self._get_cron_key(), schedule.schedule_id) # type: ignore[misc]
elif schedule.time is not None:
await redis.rpush( # type: ignore[misc]
self._get_time_key(schedule.time),
schedule.schedule_id,
time_key = self._get_time_key(schedule.time)
await redis.rpush(time_key, schedule.schedule_id) # type: ignore[misc]
# Add to the time index sorted set so we can look up
# past time keys without scanning all Redis keys.
time_val = schedule.time
if time_val.tzinfo is None:
time_val = time_val.replace(tzinfo=datetime.timezone.utc)
score = (
time_val.astimezone(datetime.timezone.utc)
.replace(second=0, microsecond=0)
.timestamp()
)
await redis.zadd( # type: ignore[misc]
self._get_time_index_key(),
{time_key: score},
)
elif schedule.interval:
await redis.rpush( # type: ignore[misc]
Expand All @@ -190,19 +280,19 @@ async def get_schedules(self) -> list["ScheduledTask"]:
Get all schedules.

This function gets all the schedules from the schedule source.
What it does is get all the cron schedules and time schedules
for the current time and return them.
What it does is get all the cron schedules, interval schedules,
past time schedules, and current-minute time schedules and
return them.

If it's the first run, it also gets all the time schedules
that are in the past and haven't been sent yet.
Past time schedules are fetched on every call so that
schedules added after the previous call but before the
minute rolled over are never missed.
"""
schedules = []
current_time = datetime.datetime.now(datetime.timezone.utc)
timed: list[bytes] = []
# Only during first run, we need to get previous time schedules
if not self._skip_past_schedules:
timed = await self._get_previous_time_schedules()
self._is_first_run = False
timed = await self._get_previous_time_schedules(current_time)
async with Redis(connection_pool=self._connection_pool) as redis:
buffer = []
crons = await redis.lrange(self._get_cron_key(), 0, -1) # type: ignore[misc]
Expand Down
Loading