Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apps/code/src/main/di/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { FoldersService } from "../services/folders/service";
import { FsService } from "../services/fs/service";
import { GitService } from "../services/git/service";
import { GitHubIntegrationService } from "../services/github-integration/service";
import { HandoffService } from "../services/handoff/service";
import { LinearIntegrationService } from "../services/linear-integration/service";
import { LlmGatewayService } from "../services/llm-gateway/service";
import { McpAppsService } from "../services/mcp-apps/service";
Expand Down Expand Up @@ -88,6 +89,7 @@ container
.bind(MAIN_TOKENS.GitHubIntegrationService)
.to(GitHubIntegrationService);
container.bind(MAIN_TOKENS.GitService).to(GitService);
container.bind(MAIN_TOKENS.HandoffService).to(HandoffService);
container
.bind(MAIN_TOKENS.LinearIntegrationService)
.to(LinearIntegrationService);
Expand Down
1 change: 1 addition & 0 deletions apps/code/src/main/di/tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export const MAIN_TOKENS = Object.freeze({
FoldersService: Symbol.for("Main.FoldersService"),
FsService: Symbol.for("Main.FsService"),
GitService: Symbol.for("Main.GitService"),
HandoffService: Symbol.for("Main.HandoffService"),
GitHubIntegrationService: Symbol.for("Main.GitHubIntegrationService"),
LinearIntegrationService: Symbol.for("Main.LinearIntegrationService"),
DeepLinkService: Symbol.for("Main.DeepLinkService"),
Expand Down
13 changes: 13 additions & 0 deletions apps/code/src/main/services/agent/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,19 @@ When creating pull requests, add the following footer at the end of the PR descr
]);
}

/**
 * Stores `context` as the pending context of the session registered under
 * `taskRunId`.
 *
 * When no session is registered for that id, a warning is logged and nothing
 * else happens.
 *
 * @param taskRunId - Key used to look the session up in `this.sessions`.
 * @param context - Text assigned to the session's `pendingContext`.
 */
setPendingContext(taskRunId: string, context: string): void {
  const target = this.sessions.get(taskRunId);
  if (target) {
    target.pendingContext = context;
    // Only the length is logged, not the context text itself.
    log.info("Set pending context on session", {
      taskRunId,
      contextLength: context.length,
    });
    return;
  }
  log.warn("Session not found for setPendingContext", { taskRunId });
}

/**
* Notify a session of a context change (CWD moved, detached HEAD, etc).
* Used when focusing/unfocusing worktrees - the agent doesn't need to respawn
Expand Down
340 changes: 340 additions & 0 deletions apps/code/src/main/services/handoff/handoff-saga.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
import type { TreeSnapshotEvent } from "@posthog/agent/types";
import { beforeEach, describe, expect, it, vi } from "vitest";
import type { HandoffSagaDeps, HandoffSagaInput } from "./handoff-saga";
import { HandoffSaga } from "./handoff-saga";

// Mock functions are created through vi.hoisted so they already exist when the
// vi.mock factory below runs (Vitest hoists vi.mock calls to the top of the file).
const mockResumeFromLog = vi.hoisted(() => vi.fn());
const mockFormatConversation = vi.hoisted(() => vi.fn());

// Replace the "@posthog/agent/resume" module with one exposing the two mocks
// above, so each test can program what they return.
vi.mock("@posthog/agent/resume", () => ({
resumeFromLog: mockResumeFromLog,
formatConversationForResume: mockFormatConversation,
}));

/**
 * Builds a `HandoffSagaInput` fixture.
 *
 * @param overrides - Fields that replace (or add to) the fixture defaults.
 * @returns A new input object; defaults are rebuilt on every call.
 */
function createInput(
  overrides: Partial<HandoffSagaInput> = {},
): HandoffSagaInput {
  const defaults: HandoffSagaInput = {
    taskId: "task-1",
    runId: "run-1",
    repoPath: "/repo",
    apiHost: "https://us.posthog.com",
    teamId: 2,
  };
  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}

/**
 * Builds a `TreeSnapshotEvent` fixture describing a single added file.
 *
 * @param overrides - Fields that replace the fixture defaults; an explicit
 *   `undefined` value also replaces the default.
 * @returns A new snapshot object; defaults are rebuilt on every call.
 */
function createSnapshot(
  overrides: Partial<TreeSnapshotEvent> = {},
): TreeSnapshotEvent {
  const defaults: TreeSnapshotEvent = {
    treeHash: "abc123",
    baseCommit: "def456",
    archiveUrl: "https://s3.example.com/archive.tar.gz",
    changes: [{ path: "test.txt", status: "A" }],
    timestamp: "2026-04-07T00:00:00Z",
  };
  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}

/**
 * Builds a `HandoffSagaDeps` in which every member is a vitest mock.
 *
 * @param overrides - Members that replace the default mocks.
 * @returns A new dependency bag; fresh mocks are created on every call.
 */
function createDeps(overrides: Partial<HandoffSagaDeps> = {}): HandoffSagaDeps {
  // Client returned by createApiClient: its task run carries a log URL.
  const apiClient = {
    getTaskRun: vi.fn().mockResolvedValue({
      log_url: "https://logs.example.com/run-1.ndjson",
    }),
  };
  // Value the default reconnectSession mock resolves with.
  const reconnected = {
    sessionId: "session-1",
    channel: "ch-1",
  };

  const defaults: HandoffSagaDeps = {
    createApiClient: vi.fn().mockReturnValue(apiClient),
    applyTreeSnapshot: vi.fn().mockResolvedValue(undefined),
    updateWorkspaceMode: vi.fn(),
    reconnectSession: vi.fn().mockResolvedValue(reconnected),
    closeCloudRun: vi.fn().mockResolvedValue(undefined),
    seedLocalLogs: vi.fn().mockResolvedValue(undefined),
    killSession: vi.fn().mockResolvedValue(undefined),
    setPendingContext: vi.fn(),
    onProgress: vi.fn(),
  };
  // Overrides are spread last so they win over the default mocks.
  return { ...defaults, ...overrides };
}

describe("HandoffSaga", () => {
  /**
   * Programs the mocked `resumeFromLog` to resolve with an empty,
   * uninterrupted resume state. `overrides` replaces individual fields.
   */
  const stubResume = (overrides: Record<string, unknown> = {}): void => {
    mockResumeFromLog.mockResolvedValue({
      conversation: [],
      latestSnapshot: null,
      snapshotApplied: false,
      interrupted: false,
      logEntryCount: 0,
      ...overrides,
    });
  };

  /** Builds a single user conversation turn holding one text part. */
  const userTurn = (text: string) => ({
    role: "user",
    content: [{ type: "text", text }],
  });

  /** Runs a freshly constructed saga over `deps` with the given input. */
  const runSaga = (
    deps: HandoffSagaDeps,
    input: HandoffSagaInput = createInput(),
  ) => new HandoffSaga(deps).run(input);

  beforeEach(() => {
    vi.clearAllMocks();
    mockFormatConversation.mockReturnValue("conversation summary");
  });

  it("completes happy path with snapshot", async () => {
    stubResume({
      conversation: [userTurn("hello")],
      latestSnapshot: createSnapshot(),
      logEntryCount: 10,
    });

    const result = await runSaga(createDeps());

    expect(result.success).toBe(true);
    // Narrows the result type for the assertions below.
    if (!result.success) return;
    expect(result.data.sessionId).toBe("session-1");
    expect(result.data.snapshotApplied).toBe(true);
    expect(result.data.conversationTurns).toBe(1);
  });

  it("closes cloud run before fetching logs", async () => {
    stubResume();
    const deps = createDeps();

    await runSaga(deps);

    expect(deps.closeCloudRun).toHaveBeenCalledWith(
      "task-1",
      "run-1",
      "https://us.posthog.com",
      2,
    );
    // Compare the invocation order of the first call to each mock.
    const [closedAt] = (deps.closeCloudRun as ReturnType<typeof vi.fn>).mock
      .invocationCallOrder;
    const [fetchedAt] = mockResumeFromLog.mock.invocationCallOrder;
    expect(closedAt).toBeLessThan(fetchedAt);
  });

  it("skips snapshot apply when no archiveUrl", async () => {
    stubResume({
      latestSnapshot: createSnapshot({ archiveUrl: undefined }),
      logEntryCount: 5,
    });
    const deps = createDeps();

    const result = await runSaga(deps);

    expect(result.success).toBe(true);
    if (!result.success) return;
    expect(result.data.snapshotApplied).toBe(false);
    expect(deps.applyTreeSnapshot).not.toHaveBeenCalled();
  });

  it("skips snapshot apply when no snapshot at all", async () => {
    stubResume();
    const deps = createDeps();

    const result = await runSaga(deps);

    expect(result.success).toBe(true);
    if (!result.success) return;
    expect(result.data.snapshotApplied).toBe(false);
    expect(deps.applyTreeSnapshot).not.toHaveBeenCalled();
  });

  it("seeds local logs when cloudLogUrl is present", async () => {
    stubResume();
    const deps = createDeps();

    await runSaga(deps);

    expect(deps.seedLocalLogs).toHaveBeenCalledWith(
      "run-1",
      "https://logs.example.com/run-1.ndjson",
    );
  });

  it("skips seeding logs when cloudLogUrl is falsy", async () => {
    stubResume();
    // The task run returned by this client has no log URL.
    const clientWithoutLogUrl = {
      getTaskRun: vi.fn().mockResolvedValue({ log_url: undefined }),
    };
    const deps = createDeps({
      createApiClient: vi.fn().mockReturnValue(clientWithoutLogUrl),
    });

    await runSaga(deps);

    expect(deps.seedLocalLogs).not.toHaveBeenCalled();
  });

  it("sets pending context with handoff summary", async () => {
    stubResume({
      conversation: [userTurn("hello")],
      logEntryCount: 1,
    });
    mockFormatConversation.mockReturnValue("User said hello");
    const deps = createDeps();

    await runSaga(deps);

    // Both fragments must appear in a context set for this run.
    const expectedFragments = [
      "resuming a previous conversation",
      "could not be restored",
    ];
    for (const fragment of expectedFragments) {
      expect(deps.setPendingContext).toHaveBeenCalledWith(
        "run-1",
        expect.stringContaining(fragment),
      );
    }
  });

  it("context mentions files restored when snapshot applied", async () => {
    stubResume({ latestSnapshot: createSnapshot() });
    const deps = createDeps();

    await runSaga(deps);

    expect(deps.setPendingContext).toHaveBeenCalledWith(
      "run-1",
      expect.stringContaining("fully restored"),
    );
  });

  it("passes sessionId and adapter through to reconnectSession", async () => {
    stubResume();
    const deps = createDeps();

    await runSaga(
      deps,
      createInput({ sessionId: "ses-abc", adapter: "codex" }),
    );

    expect(deps.reconnectSession).toHaveBeenCalledWith(
      expect.objectContaining({
        sessionId: "ses-abc",
        adapter: "codex",
      }),
    );
  });

  it("emits progress events in order", async () => {
    stubResume({ latestSnapshot: createSnapshot() });
    const deps = createDeps();

    await runSaga(deps);

    // The first argument of every onProgress call, in call order.
    const emitted = (
      deps.onProgress as ReturnType<typeof vi.fn>
    ).mock.calls.map((args: unknown[]) => args[0]);
    expect(emitted).toEqual([
      "fetching_logs",
      "applying_snapshot",
      "spawning_agent",
      "complete",
    ]);
  });

  describe("rollbacks", () => {
    it("rolls back workspace mode when spawn_agent fails", async () => {
      stubResume();
      const deps = createDeps({
        reconnectSession: vi.fn().mockRejectedValue(new Error("spawn failed")),
      });

      const result = await runSaga(deps);

      expect(result.success).toBe(false);
      // Narrows the result type for the assertions below.
      if (result.success) return;
      expect(result.failedStep).toBe("spawn_agent");
      expect(deps.updateWorkspaceMode).toHaveBeenCalledWith("task-1", "cloud");
    });

    // NOTE(review): the title mentions killing the session, but nothing here
    // asserts on deps.killSession — confirm the intended expectation.
    it("kills session on rollback if spawn partially succeeded", async () => {
      stubResume();
      const deps = createDeps({
        reconnectSession: vi.fn().mockResolvedValue(null),
      });

      const result = await runSaga(deps);

      expect(result.success).toBe(false);
      if (result.success) return;
      expect(result.failedStep).toBe("spawn_agent");
    });

    it("fails at fetch_and_rebuild without rolling back workspace", async () => {
      mockResumeFromLog.mockRejectedValue(new Error("API down"));
      const deps = createDeps();

      const result = await runSaga(deps);

      expect(result.success).toBe(false);
      if (result.success) return;
      expect(result.failedStep).toBe("fetch_and_rebuild");
      expect(deps.updateWorkspaceMode).not.toHaveBeenCalled();
      expect(deps.reconnectSession).not.toHaveBeenCalled();
    });
  });
});
Loading
Loading