Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions src/core/discord_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import asyncio
import logging

from core.config import settings
from core.discord_fetcher import fetch_channel_messages
from core.ingestion import (
_docs_from_discord_messages,
delete_message_from_store,
ingest_documents,
update_message_in_store,
)
from core.state import update_last_ingested_message_id
from core.state import get_last_ingested_message_id, update_last_ingested_message_id

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +34,17 @@ async def handle_message_create(message: dict, channel_ids: set[str]) -> None:
if not docs:
return

count = await asyncio.to_thread(ingest_documents, docs)
try:
count = await asyncio.to_thread(ingest_documents, docs)
except Exception as e:
logger.error(
"[listener] failed to ingest message_id=%s channel=%s: %s",
message["id"],
channel_id,
e,
)
return

if count:
update_last_ingested_message_id(channel_id, message["id"])
logger.info("[listener] ingested %d doc(s) from channel %s", count, channel_id)
Expand All @@ -44,6 +56,8 @@ async def handle_message_delete(payload: dict, channel_ids: set[str]) -> None:

Expected keys: channel_id, id (message_id).
Deletes the corresponding Qdrant point(s) if the channel is watched.
If the deleted message was the ingestion cursor, advances the cursor to the
most recent remaining message to prevent data loss on restart.
Exposed at module level so unit tests can call it without a real Discord client.
"""
channel_id = payload.get("channel_id", "")
Expand All @@ -62,6 +76,35 @@ async def handle_message_delete(payload: dict, channel_ids: set[str]) -> None:
channel_id,
)

# Check if the deleted message was the ingestion cursor and update if needed
# to prevent data loss when messages are deleted and bot restarts
current_cursor = get_last_ingested_message_id(channel_id)
if current_cursor and str(current_cursor) == str(message_id):
logger.info(
"[listener] deleted message was cursor, finding new cursor for channel=%s",
channel_id,
)
if settings.DISCORD_BOT_TOKEN:
recent = await fetch_channel_messages(
bot_token=settings.DISCORD_BOT_TOKEN,
channel_id=channel_id,
limit=1,
)
if recent:
new_cursor = recent[0]["id"]
update_last_ingested_message_id(channel_id, new_cursor)
logger.info(
"[listener] advanced cursor from %s to %s for channel=%s",
message_id,
new_cursor,
channel_id,
)
else:
logger.warning(
"[listener] deleted cursor message but no remaining messages in channel=%s",
channel_id,
)


async def handle_message_update(payload: dict, channel_ids: set[str]) -> None:
"""
Expand Down
147 changes: 146 additions & 1 deletion tests/unit/test_discord_listener.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Unit tests for discord_listener.handle_message_create.
Unit tests for discord_listener.handle_message_create and handle_message_delete.

We test the dict-based handler in isolation — no real Discord connection or
Qdrant needed. The discord.py import only happens inside run_discord_listener(),
Expand Down Expand Up @@ -101,6 +101,30 @@ def test_handle_message_create_skips_blank_content(monkeypatch):
assert state_updates == []


def test_handle_message_create_handles_ingestion_error(monkeypatch, caplog):
"""Ingestion failures must not update cursor and should log the error."""
import logging

state_updates = []

def fake_ingest(docs):
raise RuntimeError("Qdrant connection timeout")

def fake_update_cursor(channel_id, msg_id):
state_updates.append((channel_id, msg_id))

monkeypatch.setattr(discord_listener, "ingest_documents", fake_ingest)
monkeypatch.setattr(
discord_listener, "update_last_ingested_message_id", fake_update_cursor
)

with caplog.at_level(logging.ERROR, logger="core.discord_listener"):
asyncio.run(discord_listener.handle_message_create(VALID_MESSAGE, WATCHED))

assert state_updates == [], "Cursor must not be updated on ingestion failure"
assert any("failed to ingest" in r.message for r in caplog.records)


# ---------------------------------------------------------------------------
# _message_to_dict — shape contract
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -156,3 +180,124 @@ def test_run_discord_listener_returns_early_without_channels(caplog):
with caplog.at_level(logging.ERROR, logger="core.discord_listener"):
asyncio.run(discord_listener.run_discord_listener("some-token", []))
assert any("No channel IDs" in r.message for r in caplog.records)


# ---------------------------------------------------------------------------
# handle_message_delete — cursor handling
# ---------------------------------------------------------------------------


def test_handle_message_delete_removes_point(monkeypatch):
"""Deleting a message must remove the corresponding Qdrant point."""
deleted = []

def fake_delete(channel_id, message_id):
deleted.append((channel_id, message_id))
return 1

monkeypatch.setattr(discord_listener, "delete_message_from_store", fake_delete)
monkeypatch.setattr(
discord_listener, "get_last_ingested_message_id", lambda ch: None
)

asyncio.run(
discord_listener.handle_message_delete(
{"channel_id": "111111111111111111", "id": "999"},
WATCHED,
)
)

assert deleted == [("111111111111111111", "999")]


def test_handle_message_delete_ignores_wrong_channel(monkeypatch):
"""Messages in unwatched channels must not trigger deletion."""
called = []
monkeypatch.setattr(
discord_listener, "delete_message_from_store", lambda *a: called.append(a) or 0
)

asyncio.run(
discord_listener.handle_message_delete(
{"channel_id": "999999999999999999", "id": "999"},
WATCHED,
)
)

assert called == []


def test_handle_message_delete_updates_cursor_when_deleted_message_was_cursor(
monkeypatch,
):
"""When the deleted message was the cursor, cursor must advance to most recent."""
deleted = []
cursor_updates = []
fetched_messages = []

def fake_delete(channel_id, message_id):
deleted.append((channel_id, message_id))
return 1

def fake_get_cursor(channel_id):
return "999" # The message being deleted is the cursor

def fake_update_cursor(channel_id, msg_id):
cursor_updates.append((channel_id, msg_id))

async def fake_fetch(bot_token, channel_id, limit):
fetched_messages.append((bot_token, channel_id, limit))
return [{"id": "1000", "content": "newest message"}]

monkeypatch.setattr(discord_listener, "delete_message_from_store", fake_delete)
monkeypatch.setattr(discord_listener, "get_last_ingested_message_id", fake_get_cursor)
monkeypatch.setattr(
discord_listener, "update_last_ingested_message_id", fake_update_cursor
)
monkeypatch.setattr(discord_listener, "fetch_channel_messages", fake_fetch)
monkeypatch.setattr(discord_listener.settings, "DISCORD_BOT_TOKEN", "fake-token")

asyncio.run(
discord_listener.handle_message_delete(
{"channel_id": "111111111111111111", "id": "999"},
WATCHED,
)
)

assert deleted == [("111111111111111111", "999")]
assert cursor_updates == [("111111111111111111", "1000")]
assert fetched_messages == [("fake-token", "111111111111111111", 1)]


def test_handle_message_delete_does_not_update_cursor_when_not_cursor(
monkeypatch,
):
"""When the deleted message was NOT the cursor, cursor must not change."""
deleted = []
cursor_updates = []

def fake_delete(channel_id, message_id):
deleted.append((channel_id, message_id))
return 1

def fake_get_cursor(channel_id):
return "888" # Different from the message being deleted

def fake_update_cursor(channel_id, msg_id):
cursor_updates.append((channel_id, msg_id))

monkeypatch.setattr(discord_listener, "delete_message_from_store", fake_delete)
monkeypatch.setattr(discord_listener, "get_last_ingested_message_id", fake_get_cursor)
monkeypatch.setattr(
discord_listener, "update_last_ingested_message_id", fake_update_cursor
)

asyncio.run(
discord_listener.handle_message_delete(
{"channel_id": "111111111111111111", "id": "999"},
WATCHED,
)
)

assert deleted == [("111111111111111111", "999")]
assert cursor_updates == []
Loading