From cb4c9f0c2580114578a5797a2ad83d9ef30d3813 Mon Sep 17 00:00:00 2001 From: JonathanLab Date: Fri, 10 Apr 2026 14:04:42 +0200 Subject: [PATCH] feat: use checkpoints for handoff --- apps/code/src/main/services/agent/service.ts | 4 +- .../services/handoff/handoff-saga.test.ts | 283 ++++---- .../src/main/services/handoff/handoff-saga.ts | 57 +- .../code/src/main/services/handoff/schemas.ts | 4 + .../src/main/services/handoff/service.test.ts | 24 + .../code/src/main/services/handoff/service.ts | 73 +- .../features/sessions/service/service.ts | 22 +- .../task-detail/hooks/useTaskCreation.ts | 3 +- packages/agent/package.json | 8 + packages/agent/src/acp-extensions.ts | 3 + packages/agent/src/handoff-checkpoint.test.ts | 183 +++++ packages/agent/src/handoff-checkpoint.ts | 361 ++++++++++ packages/agent/src/posthog-api.test.ts | 29 + packages/agent/src/posthog-api.ts | 6 +- packages/agent/src/resume.ts | 8 +- .../agent/src/sagas/apply-snapshot-saga.ts | 7 + packages/agent/src/sagas/capture-tree-saga.ts | 13 +- packages/agent/src/sagas/resume-saga.ts | 32 + packages/agent/src/sagas/test-fixtures.ts | 46 ++ packages/agent/src/server/agent-server.ts | 75 +- packages/agent/src/server/schemas.ts | 23 +- packages/agent/src/types.ts | 24 + packages/agent/tsup.config.ts | 2 + packages/git/src/handoff.test.ts | 391 +++++++++++ packages/git/src/handoff.ts | 639 ++++++++++++++++++ packages/git/src/sagas/checkpoint.ts | 1 + 26 files changed, 2153 insertions(+), 168 deletions(-) create mode 100644 packages/agent/src/handoff-checkpoint.test.ts create mode 100644 packages/agent/src/handoff-checkpoint.ts create mode 100644 packages/git/src/handoff.test.ts create mode 100644 packages/git/src/handoff.ts diff --git a/apps/code/src/main/services/agent/service.ts b/apps/code/src/main/services/agent/service.ts index 45fc6a26f..f881154f0 100644 --- a/apps/code/src/main/services/agent/service.ts +++ b/apps/code/src/main/services/agent/service.ts @@ -34,7 +34,7 @@ import { isOpenAIModel, } from "@posthog/agent/gateway-models"; import { getLlmGatewayUrl } from "@posthog/agent/posthog-api"; -import type { OnLogCallback } from "@posthog/agent/types"; +import type * as AgentTypes from "@posthog/agent/types"; import { getCurrentBranch } from "@posthog/git/queries"; import { isAuthError } from "@shared/errors"; import type { AcpMessage } from "@shared/types/session-events"; @@ -177,7 +177,7 @@ function createTappedWritableStream( }); } -const onAgentLog: OnLogCallback = (level, scope, message, data) => { +const onAgentLog: AgentTypes.OnLogCallback = (level, scope, message, data) => { const scopedLog = logger.scope(scope); if (data !== undefined) { scopedLog[level as keyof typeof scopedLog](message, data); diff --git a/apps/code/src/main/services/handoff/handoff-saga.test.ts b/apps/code/src/main/services/handoff/handoff-saga.test.ts index 9d89332de..0137b2b4c 100644 --- a/apps/code/src/main/services/handoff/handoff-saga.test.ts +++ b/apps/code/src/main/services/handoff/handoff-saga.test.ts @@ -1,4 +1,5 @@ -import type { TreeSnapshotEvent } from "@posthog/agent/types"; +import type * as AgentResume from "@posthog/agent/resume"; +import type * as AgentTypes from "@posthog/agent/types"; import { beforeEach, describe, expect, it, vi } from "vitest"; import type { HandoffSagaDeps, HandoffSagaInput } from "./handoff-saga"; import { HandoffSaga } from "./handoff-saga"; @@ -6,6 +7,14 @@ import { HandoffSaga } from "./handoff-saga"; const mockResumeFromLog = vi.hoisted(() => vi.fn()); const mockFormatConversation = vi.hoisted(() => vi.fn()); +const DEFAULT_LOCAL_GIT_STATE = { + head: "abc123", + branch: "feature/handoff", + upstreamHead: null, + upstreamRemote: "origin", + upstreamMergeRef: "refs/heads/feature/handoff", +}; + vi.mock("@posthog/agent/resume", () => ({ resumeFromLog: mockResumeFromLog, formatConversationForResume: mockFormatConversation, @@ -25,8 +34,8 @@ function createInput( } function createSnapshot( - overrides: Partial = {}, -): TreeSnapshotEvent { + overrides: Partial = {}, +): AgentTypes.TreeSnapshotEvent { return { treeHash: "abc123", baseCommit: "def456", @@ -37,6 +46,27 @@ function createSnapshot( }; } +function createCheckpoint( + overrides: Partial = {}, +): AgentTypes.GitCheckpointEvent { + return { + checkpointId: "checkpoint-1", + commit: "checkpointcommit123", + checkpointRef: "refs/posthog-code-checkpoint/checkpoint-1", + headRef: "refs/posthog-code-handoff/head/checkpoint-1", + head: "def456", + branch: "feature/handoff", + indexTree: "index123", + worktreeTree: "worktree123", + artifactPath: "gs://bucket/checkpoint-1.bundle", + timestamp: "2026-04-07T00:00:00Z", + upstreamRemote: "origin", + upstreamMergeRef: "refs/heads/feature/handoff", + remoteUrl: "git@github.com:PostHog/code.git", + ...overrides, + }; +} + function createDeps(overrides: Partial = {}): HandoffSagaDeps { return { createApiClient: vi.fn().mockReturnValue({ @@ -45,6 +75,7 @@ function createDeps(overrides: Partial = {}): HandoffSagaDeps { }), }), applyTreeSnapshot: vi.fn().mockResolvedValue(undefined), + applyGitCheckpoint: vi.fn().mockResolvedValue(undefined), updateWorkspaceMode: vi.fn(), reconnectSession: vi.fn().mockResolvedValue({ sessionId: "session-1", @@ -59,6 +90,40 @@ function createDeps(overrides: Partial = {}): HandoffSagaDeps { }; } +function createResumeState( + overrides: Partial = {}, +): AgentResume.ResumeState { + return { + conversation: [], + latestSnapshot: null, + latestGitCheckpoint: null, + snapshotApplied: false, + interrupted: false, + logEntryCount: 0, + ...overrides, + }; +} + +function getProgressSteps(deps: HandoffSagaDeps): string[] { + return (deps.onProgress as ReturnType).mock.calls.map( + (call: unknown[]) => call[0] as string, + ); +} + +async function runSaga( + overrides: { + input?: Partial; + deps?: Partial; + resumeState?: Partial; + } = {}, +) { + mockResumeFromLog.mockResolvedValue(createResumeState(overrides.resumeState)); + const deps = createDeps(overrides.deps); + const saga = new HandoffSaga(deps); + const result = await saga.run(createInput(overrides.input)); + return { deps, result }; +} + describe("HandoffSaga", () => { beforeEach(() => { vi.clearAllMocks(); @@ -67,20 +132,16 @@ describe("HandoffSaga", () => { it("completes happy path with snapshot", async () => { const snapshot = createSnapshot(); - mockResumeFromLog.mockResolvedValue({ - conversation: [ - { role: "user", content: [{ type: "text", text: "hello" }] }, - ], - latestSnapshot: snapshot, - snapshotApplied: false, - interrupted: false, - logEntryCount: 10, + const { result } = await runSaga({ + resumeState: { + conversation: [ + { role: "user", content: [{ type: "text", text: "hello" }] }, + ], + latestSnapshot: snapshot, + logEntryCount: 10, + }, }); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - const result = await saga.run(createInput()); - expect(result.success).toBe(true); if (!result.success) return; expect(result.data.sessionId).toBe("session-1"); @@ -89,23 +150,14 @@ describe("HandoffSaga", () => { }); it("closes cloud run before fetching logs", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); + const { deps } = await runSaga(); expect(deps.closeCloudRun).toHaveBeenCalledWith( "task-1", "run-1", "https://us.posthog.com", 2, + undefined, ); const closeOrder = (deps.closeCloudRun as ReturnType).mock .invocationCallOrder[0]; @@ -114,18 +166,13 @@ describe("HandoffSaga", () => { }); it("skips snapshot apply when no archiveUrl", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: createSnapshot({ archiveUrl: undefined }), - snapshotApplied: false, - interrupted: false, - logEntryCount: 5, + const { deps, result } = await runSaga({ + resumeState: { + latestSnapshot: createSnapshot({ archiveUrl: undefined }), + logEntryCount: 5, + }, }); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - const result = await saga.run(createInput()); - expect(result.success).toBe(true); if (!result.success) return; expect(result.data.snapshotApplied).toBe(false); @@ -133,17 +180,7 @@ describe("HandoffSaga", () => { }); it("skips snapshot apply when no snapshot at all", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - - const deps = createDeps(); - const saga = new HandoffSaga(deps); - const result = await saga.run(createInput()); + const { deps, result } = await runSaga(); expect(result.success).toBe(true); if (!result.success) return; @@ -152,17 +189,7 @@ describe("HandoffSaga", () => { }); it("seeds local logs when cloudLogUrl is present", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); + const { deps } = await runSaga(); expect(deps.seedLocalLogs).toHaveBeenCalledWith( "run-1", @@ -171,41 +198,29 @@ describe("HandoffSaga", () => { }); it("skips seeding logs when cloudLogUrl is falsy", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - const apiClient = { getTaskRun: vi.fn().mockResolvedValue({ log_url: undefined }), }; - const deps = createDeps({ - createApiClient: vi.fn().mockReturnValue(apiClient), + const { deps } = await runSaga({ + deps: { + createApiClient: vi.fn().mockReturnValue(apiClient), + }, }); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); expect(deps.seedLocalLogs).not.toHaveBeenCalled(); }); it("sets pending context with handoff summary", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [ - { role: "user", content: [{ type: "text", text: "hello" }] }, - ], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 1, - }); mockFormatConversation.mockReturnValue("User said hello"); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); + const { deps } = await runSaga({ + resumeState: { + conversation: [ + { role: "user", content: [{ type: "text", text: "hello" }] }, + ], + logEntryCount: 1, + }, + }); expect(deps.setPendingContext).toHaveBeenCalledWith( "run-1", @@ -218,18 +233,12 @@ describe("HandoffSaga", () => { }); it("context mentions files restored when snapshot applied", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: createSnapshot(), - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, + const { deps } = await runSaga({ + resumeState: { + latestSnapshot: createSnapshot(), + }, }); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); - expect(deps.setPendingContext).toHaveBeenCalledWith( "run-1", expect.stringContaining("fully restored"), @@ -237,18 +246,10 @@ describe("HandoffSaga", () => { }); it("passes sessionId and adapter through to reconnectSession", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, + const { deps } = await runSaga({ + input: { sessionId: "ses-abc", adapter: "codex" }, }); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput({ sessionId: "ses-abc", adapter: "codex" })); - expect(deps.reconnectSession).toHaveBeenCalledWith( expect.objectContaining({ sessionId: "ses-abc", @@ -258,23 +259,16 @@ describe("HandoffSaga", () => { }); it("emits progress events in order", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: createSnapshot(), - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, + const { deps } = await runSaga({ + resumeState: { + latestGitCheckpoint: createCheckpoint(), + latestSnapshot: createSnapshot(), + }, }); - const deps = createDeps(); - const saga = new HandoffSaga(deps); - await saga.run(createInput()); - - const progressCalls = (deps.onProgress as ReturnType).mock - .calls; - const steps = progressCalls.map((call: unknown[]) => call[0]); - expect(steps).toEqual([ + expect(getProgressSteps(deps)).toEqual([ "fetching_logs", + "applying_git_checkpoint", "applying_snapshot", "spawning_agent", "complete", @@ -283,19 +277,13 @@ describe("HandoffSaga", () => { describe("rollbacks", () => { it("rolls back workspace mode when spawn_agent fails", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - - const deps = createDeps({ - reconnectSession: vi.fn().mockRejectedValue(new Error("spawn failed")), + const { deps, result } = await runSaga({ + deps: { + reconnectSession: vi + .fn() + .mockRejectedValue(new Error("spawn failed")), + }, }); - const saga = new HandoffSaga(deps); - const result = await saga.run(createInput()); expect(result.success).toBe(false); if (result.success) return; @@ -304,19 +292,11 @@ describe("HandoffSaga", () => { }); it("kills session on rollback if spawn partially succeeded", async () => { - mockResumeFromLog.mockResolvedValue({ - conversation: [], - latestSnapshot: null, - snapshotApplied: false, - interrupted: false, - logEntryCount: 0, - }); - - const deps = createDeps({ - reconnectSession: vi.fn().mockResolvedValue(null), + const { result } = await runSaga({ + deps: { + reconnectSession: vi.fn().mockResolvedValue(null), + }, }); - const saga = new HandoffSaga(deps); - const result = await saga.run(createInput()); expect(result.success).toBe(false); if (result.success) return; @@ -337,4 +317,27 @@ describe("HandoffSaga", () => { expect(deps.reconnectSession).not.toHaveBeenCalled(); }); }); + + it("applies git checkpoint before restoring the file snapshot", async () => { + const { deps, result } = await runSaga({ + input: { localGitState: DEFAULT_LOCAL_GIT_STATE }, + resumeState: { + latestSnapshot: createSnapshot(), + latestGitCheckpoint: createCheckpoint(), + }, + }); + + expect(result.success).toBe(true); + if (!result.success) return; + expect(deps.applyGitCheckpoint).toHaveBeenCalledTimes(1); + expect(deps.applyGitCheckpoint).toHaveBeenCalledWith( + expect.any(Object), + "/repo", + "task-1", + "run-1", + expect.any(Object), + DEFAULT_LOCAL_GIT_STATE, + ); + expect(deps.applyTreeSnapshot).toHaveBeenCalledTimes(1); + }); }); diff --git a/apps/code/src/main/services/handoff/handoff-saga.ts b/apps/code/src/main/services/handoff/handoff-saga.ts index b0bbba45b..f39e74f76 100644 --- a/apps/code/src/main/services/handoff/handoff-saga.ts +++ b/apps/code/src/main/services/handoff/handoff-saga.ts @@ -1,10 +1,10 @@ import type { PostHogAPIClient } from "@posthog/agent/posthog-api"; +import type * as AgentResume from "@posthog/agent/resume"; import { - type ConversationTurn, formatConversationForResume, resumeFromLog, } from "@posthog/agent/resume"; -import type { TreeSnapshotEvent } from "@posthog/agent/types"; +import type * as AgentTypes from "@posthog/agent/types"; import { Saga, type SagaLogger } from "@posthog/shared"; import type { WorkspaceMode } from "../../db/repositories/workspace-repository"; import type { SessionResponse } from "../agent/schemas"; @@ -18,6 +18,7 @@ export interface HandoffSagaInput { teamId: number; sessionId?: string; adapter?: "claude" | "codex"; + localGitState?: AgentTypes.HandoffLocalGitState; } export interface HandoffSagaOutput { @@ -29,12 +30,20 @@ export interface HandoffSagaOutput { export interface HandoffSagaDeps { createApiClient(apiHost: string, teamId: number): PostHogAPIClient; applyTreeSnapshot( - snapshot: TreeSnapshotEvent, + snapshot: AgentTypes.TreeSnapshotEvent, repoPath: string, taskId: string, runId: string, apiClient: PostHogAPIClient, ): Promise; + applyGitCheckpoint( + checkpoint: AgentTypes.GitCheckpointEvent, + repoPath: string, + taskId: string, + runId: string, + apiClient: PostHogAPIClient, + localGitState?: AgentTypes.HandoffLocalGitState, + ): Promise; updateWorkspaceMode(taskId: string, mode: WorkspaceMode): void; reconnectSession(params: { taskId: string; @@ -51,6 +60,7 @@ export interface HandoffSagaDeps { runId: string, apiHost: string, teamId: number, + localGitState?: AgentTypes.HandoffLocalGitState, ): Promise; seedLocalLogs(runId: string, logUrl: string): Promise; killSession(taskRunId: string): Promise; @@ -76,7 +86,13 @@ export class HandoffSaga extends Saga { ); await this.readOnlyStep("close_cloud_run", async () => { - await this.deps.closeCloudRun(taskId, runId, apiHost, teamId); + await this.deps.closeCloudRun( + taskId, + runId, + apiHost, + teamId, + input.localGitState, + ); }); const apiClient = this.deps.createApiClient(apiHost, teamId); @@ -94,7 +110,30 @@ export class HandoffSaga extends Saga { }, ); - let snapshotApplied = false; + let filesRestored = false; + const checkpoint = resumeState.latestGitCheckpoint; + if (checkpoint) { + this.deps.onProgress( + "applying_git_checkpoint", + "Applying cloud git state locally...", + ); + + await this.step({ + name: "apply_git_checkpoint", + execute: async () => { + await this.deps.applyGitCheckpoint( + checkpoint, + repoPath, + taskId, + runId, + apiClient, + input.localGitState, + ); + }, + rollback: async () => {}, + }); + } + const snapshot = resumeState.latestSnapshot; if (snapshot?.archiveUrl) { this.deps.onProgress( @@ -112,7 +151,7 @@ export class HandoffSaga extends Saga { runId, apiClient, ); - snapshotApplied = true; + filesRestored = true; }, rollback: async () => {}, }); @@ -167,7 +206,7 @@ export class HandoffSaga extends Saga { await this.readOnlyStep("set_context", async () => { const context = this.buildHandoffContext( resumeState.conversation, - snapshotApplied, + filesRestored, ); this.deps.setPendingContext(runId, context); }); @@ -176,13 +215,13 @@ export class HandoffSaga extends Saga { return { sessionId: agentSessionId, - snapshotApplied, + snapshotApplied: filesRestored, conversationTurns: resumeState.conversation.length, }; } private buildHandoffContext( - conversation: ConversationTurn[], + conversation: AgentResume.ConversationTurn[], snapshotApplied: boolean, ): string { const conversationSummary = formatConversationForResume(conversation); diff --git a/apps/code/src/main/services/handoff/schemas.ts b/apps/code/src/main/services/handoff/schemas.ts index 3cf46cf56..a9c571cee 100644 --- a/apps/code/src/main/services/handoff/schemas.ts +++ b/apps/code/src/main/services/handoff/schemas.ts @@ -1,3 +1,4 @@ +import { handoffLocalGitStateSchema } from "@posthog/agent/server/schemas"; import { z } from "zod"; export const handoffPreflightInput = z.object({ @@ -14,6 +15,7 @@ export const handoffPreflightResult = z.object({ canHandoff: z.boolean(), reason: z.string().optional(), localTreeDirty: z.boolean(), + localGitState: handoffLocalGitStateSchema.optional(), }); export type HandoffPreflightResult = z.infer; @@ -26,6 +28,7 @@ export const handoffExecuteInput = z.object({ teamId: z.number(), sessionId: z.string().optional(), adapter: z.enum(["claude", "codex"]).optional(), + localGitState: handoffLocalGitStateSchema.optional(), }); export type HandoffExecuteInput = z.infer; @@ -40,6 +43,7 @@ export type HandoffExecuteResult = z.infer; export type HandoffStep = | "fetching_logs" + | "applying_git_checkpoint" | "applying_snapshot" | "updating_run" | "spawning_agent" diff --git a/apps/code/src/main/services/handoff/service.test.ts b/apps/code/src/main/services/handoff/service.test.ts index c529af4e2..8ed90242f 100644 --- a/apps/code/src/main/services/handoff/service.test.ts +++ b/apps/code/src/main/services/handoff/service.test.ts @@ -8,6 +8,9 @@ const mockSendCommand = vi.hoisted(() => vi.fn()); const mockCreatePosthogConfig = vi.hoisted(() => vi.fn()); const mockUpdateMode = vi.hoisted(() => vi.fn()); const mockNetFetch = vi.hoisted(() => vi.fn()); +const mockShowMessageBox = vi.hoisted(() => vi.fn()); +const mockApplyFromHandoff = vi.hoisted(() => vi.fn()); +const mockReadHandoffLocalGitState = vi.hoisted(() => vi.fn()); vi.mock("@main/utils/logger", () => ({ logger: { @@ -34,6 +37,7 @@ vi.mock("inversify", () => ({ vi.mock("electron", () => ({ app: { getPath: () => "/home" }, net: { fetch: mockNetFetch }, + dialog: { showMessageBox: mockShowMessageBox }, })); vi.mock("@posthog/agent/posthog-api", () => ({ @@ -46,6 +50,16 @@ vi.mock("@posthog/agent/tree-tracker", () => ({ })), })); +vi.mock("@posthog/agent/handoff-checkpoint", () => ({ + HandoffCheckpointTracker: vi.fn().mockImplementation(() => ({ + applyFromHandoff: mockApplyFromHandoff, + })), +})); + +vi.mock("@posthog/git/handoff", () => ({ + readHandoffLocalGitState: mockReadHandoffLocalGitState, +})); + vi.mock("@main/di/tokens", () => ({ MAIN_TOKENS: { GitService: Symbol("GitService"), @@ -59,6 +73,14 @@ vi.mock("@main/di/tokens", () => ({ import type { HandoffPreflightInput } from "./schemas"; import { HandoffService } from "./service"; +const DEFAULT_LOCAL_GIT_STATE = { + head: "abc123", + branch: "main", + upstreamHead: "def456", + upstreamRemote: "origin", + upstreamMergeRef: "refs/heads/main", +}; + function createService(): HandoffService { const gitService = { getChangedFilesHead: mockGetChangedFilesHead } as never; const agentService = { @@ -97,6 +119,7 @@ function createPreflightInput( describe("HandoffService.preflight", () => { beforeEach(() => { vi.clearAllMocks(); + mockReadHandoffLocalGitState.mockResolvedValue(DEFAULT_LOCAL_GIT_STATE); }); it("returns canHandoff=true when working tree is clean", async () => { @@ -108,6 +131,7 @@ describe("HandoffService.preflight", () => { expect(result.canHandoff).toBe(true); expect(result.localTreeDirty).toBe(false); expect(result.reason).toBeUndefined(); + expect(result.localGitState).toEqual(DEFAULT_LOCAL_GIT_STATE); }); it("returns canHandoff=false when working tree has changes", async () => { diff --git a/apps/code/src/main/services/handoff/service.ts b/apps/code/src/main/services/handoff/service.ts index 3dae58090..46d253b84 100644 --- a/apps/code/src/main/services/handoff/service.ts +++ b/apps/code/src/main/services/handoff/service.ts @@ -3,10 +3,15 @@ import { join } from "node:path"; import { MAIN_TOKENS } from "@main/di/tokens"; import { logger } from "@main/utils/logger"; import { TypedEventEmitter } from "@main/utils/typed-event-emitter"; +import { HandoffCheckpointTracker } from "@posthog/agent/handoff-checkpoint"; import { PostHogAPIClient } from "@posthog/agent/posthog-api"; import { TreeTracker } from "@posthog/agent/tree-tracker"; -import type { TreeSnapshotEvent } from "@posthog/agent/types"; -import { app, net } from "electron"; +import type * as AgentTypes from "@posthog/agent/types"; +import { + type GitHandoffBranchDivergence, + readHandoffLocalGitState, +} from "@posthog/git/handoff"; +import { app, dialog, net } from "electron"; import { inject, injectable } from "inversify"; import type { IWorkspaceRepository } from "../../db/repositories/workspace-repository"; import type { AgentAuthAdapter } from "../agent/auth-adapter"; @@ -24,6 +29,7 @@ import { } from "./schemas"; const log = logger.scope("handoff"); +const CONTINUE_DIVERGENCE_BUTTON = 1; @injectable() export class HandoffService extends TypedEventEmitter { @@ -47,9 +53,11 @@ export class HandoffService extends TypedEventEmitter { const { repoPath } = input; let localTreeDirty = false; + let localGitState: AgentTypes.HandoffLocalGitState | undefined; try { const changedFiles = await this.gitService.getChangedFilesHead(repoPath); localTreeDirty = changedFiles.length > 0; + localGitState = await this.getLocalGitState(repoPath); } catch (err) { log.warn("Failed to check local working tree", { repoPath, err }); } @@ -59,7 +67,7 @@ export class HandoffService extends TypedEventEmitter { ? "Local working tree has uncommitted changes. Commit or stash them first." : undefined; - return { canHandoff, reason, localTreeDirty }; + return { canHandoff, reason, localTreeDirty, localGitState }; } async execute(input: HandoffExecuteInput): Promise { @@ -73,7 +81,7 @@ export class HandoffService extends TypedEventEmitter { }, applyTreeSnapshot: async ( - snapshot: TreeSnapshotEvent, + snapshot: AgentTypes.TreeSnapshotEvent, repoPath: string, taskId: string, runId: string, @@ -91,13 +99,35 @@ export class HandoffService extends TypedEventEmitter { }); }, - closeCloudRun: async (taskId, runId, apiHost, teamId) => { + applyGitCheckpoint: async ( + checkpoint: AgentTypes.GitCheckpointEvent, + repoPath: string, + taskId: string, + runId: string, + apiClient: PostHogAPIClient, + localGitState?: AgentTypes.HandoffLocalGitState, + ) => { + const tracker = new HandoffCheckpointTracker({ + repositoryPath: repoPath, + taskId, + runId, + apiClient, + }); + await tracker.applyFromHandoff(checkpoint, { + localGitState, + onDivergedBranch: (divergence) => + this.confirmDivergedBranchReset(divergence), + }); + }, + + closeCloudRun: async (taskId, runId, apiHost, teamId, localGitState) => { const result = await this.cloudTaskService.sendCommand({ taskId, runId, apiHost, teamId, method: "close", + params: localGitState ? { localGitState } : undefined, }); if (!result.success) { log.warn("Close command failed, continuing with handoff", { @@ -176,4 +206,37 @@ export class HandoffService extends TypedEventEmitter { sessionId: result.data.sessionId, }; } + + private async getLocalGitState( + repoPath: string, + ): Promise { + return readHandoffLocalGitState(repoPath); + } + + private async confirmDivergedBranchReset( + divergence: GitHandoffBranchDivergence, + ): Promise { + if (typeof app.isReady === "function" && !app.isReady()) { + log.warn( + "Cannot show divergence confirmation dialog before app is ready", + { + branch: divergence.branch, + }, + ); + return false; + } + + const result = await dialog.showMessageBox({ + type: "warning", + buttons: ["Cancel", "Continue"], + defaultId: 0, + cancelId: 0, + title: "Local branch has diverged", + message: `The local branch '${divergence.branch}' has commits that are not in the cloud handoff.`, + detail: + `Continuing will reset '${divergence.branch}' from ${divergence.localHead.slice(0, 7)} to ${divergence.cloudHead.slice(0, 7)}.\n\n` + + "Cancel if you want to keep the current local branch tip.", + }); + return result.response === CONTINUE_DIVERGENCE_BUTTON; + } } diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index 0bbfb5d48..1e5e2c35c 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -2045,10 +2045,21 @@ export class SessionService { sessionStoreSetters.updateSession(runId, { handoffInProgress: true }); try { - await this.runHandoffPreflight(taskId, runId, repoPath, auth); + const preflight = await this.runHandoffPreflight( + taskId, + runId, + repoPath, + auth, + ); this.stopCloudTaskWatch(taskId); sessionStoreSetters.updateSession(runId, { status: "connecting" }); - await this.executeHandoff(taskId, runId, repoPath, auth); + await this.executeHandoff( + taskId, + runId, + repoPath, + auth, + preflight.localGitState, + ); this.transitionToLocalSession(runId); this.subscribeToChannel(runId); queryClient.invalidateQueries({ queryKey: ["tasks"] }); @@ -2094,7 +2105,7 @@ export class SessionService { runId: string, repoPath: string, auth: { apiHost: string; projectId: number }, - ): Promise { + ): Promise>> { const preflight = await trpcClient.handoff.preflight.query({ taskId, runId, @@ -2108,6 +2119,7 @@ export class SessionService { }); throw new Error(preflight.reason ?? "Cannot hand off to local"); } + return preflight; } private async executeHandoff( @@ -2115,6 +2127,9 @@ export class SessionService { runId: string, repoPath: string, auth: { apiHost: string; projectId: number }, + localGitState?: Awaited< + ReturnType + >["localGitState"], ): Promise { const result = await trpcClient.handoff.execute.mutate({ taskId, @@ -2122,6 +2137,7 @@ export class SessionService { repoPath, apiHost: auth.apiHost, teamId: auth.projectId, + localGitState, }); if (!result.success) { throw new Error(result.error ?? "Handoff failed"); diff --git a/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts b/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts index f508afef6..1e5f214ff 100644 --- a/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts +++ b/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts @@ -70,7 +70,8 @@ function prepareTaskInput( ? buildCloudTaskDescription(serializedContent, filePaths) : undefined, filePaths, - repoPath: options.selectedDirectory, + repoPath: + options.workspaceMode === "cloud" ? undefined : options.selectedDirectory, repository: options.selectedRepository, githubIntegrationId: options.githubIntegrationId, workspaceMode: options.workspaceMode, diff --git a/packages/agent/package.json b/packages/agent/package.json index 28fd0e5d4..83f7a47bb 100644 --- a/packages/agent/package.json +++ b/packages/agent/package.json @@ -56,6 +56,10 @@ "types": "./dist/resume.d.ts", "import": "./dist/resume.js" }, + "./handoff-checkpoint": { + "types": "./dist/handoff-checkpoint.d.ts", + "import": "./dist/handoff-checkpoint.js" + }, "./tree-tracker": { "types": "./dist/tree-tracker.d.ts", "import": "./dist/tree-tracker.js" @@ -63,6 +67,10 @@ "./server": { "types": "./dist/server/agent-server.d.ts", "import": "./dist/server/agent-server.js" + }, + "./server/schemas": { + "types": "./dist/server/schemas.d.ts", + "import": "./dist/server/schemas.js" } }, "bin": { diff --git a/packages/agent/src/acp-extensions.ts b/packages/agent/src/acp-extensions.ts index 62a2a1083..9dc518ef6 100644 --- a/packages/agent/src/acp-extensions.ts +++ b/packages/agent/src/acp-extensions.ts @@ -37,6 +37,9 @@ export const POSTHOG_NOTIFICATIONS = { /** Tree state snapshot captured (git tree hash + file archive) */ TREE_SNAPSHOT: "_posthog/tree_snapshot", + /** Git checkpoint captured for handoff */ + GIT_CHECKPOINT: "_posthog/git_checkpoint", + /** Agent mode changed (interactive/background) */ MODE_CHANGE: "_posthog/mode_change", diff --git a/packages/agent/src/handoff-checkpoint.test.ts b/packages/agent/src/handoff-checkpoint.test.ts new file mode 100644 index 000000000..78749ff07 --- /dev/null +++ b/packages/agent/src/handoff-checkpoint.test.ts @@ -0,0 +1,183 @@ +import { afterEach, describe, expect, it } from "vitest"; +import { HandoffCheckpointTracker } from "./handoff-checkpoint"; +import { + cloneTestRepo, + createTestRepo, + type TestRepo, +} from "./sagas/test-fixtures"; +import type { HandoffLocalGitState } from "./types"; + +interface BundleStore { + artifacts: Record; + storagePath: string; + manifest: Array<{ storage_path: string }>; +} + +interface HandoffRepos { + cloudRepo: TestRepo; + localRepo: TestRepo; + branch: string; + localGitState: HandoffLocalGitState; +} + +const WORKTREE_FILES = ["tracked.txt", "unstaged.txt", "untracked.txt"]; + +function createMockApi(store: BundleStore) { + return { + uploadTaskArtifacts: async ( + _taskId: string, + _runId: string, + artifacts: Array<{ + name: string; + content: string; + }>, + ) => { + const uploaded = artifacts.map((artifact, index) => { + const storagePath = `${store.storagePath}-${store.manifest.length + index}-${artifact.name}`; + store.artifacts[storagePath] = artifact.content; + return { storage_path: storagePath }; + }); + for (const entry of uploaded) { + store.manifest.push(entry); + } + return store.manifest; + }, + downloadArtifact: async ( + _taskId: string, + _runId: string, + artifactPath: string, + ) => { + const contentBase64 = store.artifacts[artifactPath]; + if (!contentBase64) return null; + const buffer = Buffer.from(contentBase64, "utf-8"); + return buffer.buffer.slice( + buffer.byteOffset, + buffer.byteOffset + buffer.byteLength, + ); + }, + }; +} + +function createBundleStore(): BundleStore { + return { + storagePath: "gs://bucket/handoff", + artifacts: {}, + manifest: [ + { + storage_path: "gs://bucket/handoff-0-existing-tree_snapshot.tar.gz", + }, + ], + }; +} + +function createTracker( + repositoryPath: string, + apiClient: ReturnType, +) { + return new HandoffCheckpointTracker({ + repositoryPath, + taskId: "task-1", + runId: "run-1", + apiClient: apiClient as never, + }); +} + +async function seedCloudRepo(repo: TestRepo): Promise { + await repo.writeFile("tracked.txt", "base\n"); + await repo.writeFile("unstaged.txt", "base unstaged\n"); + await repo.git(["add", "tracked.txt", "unstaged.txt"]); + await repo.git(["commit", "-m", "Add tracked files"]); +} + +async function prepareHandoffRepos( + cleanups: Array<() => Promise>, +): Promise { + const cloudRepo = await createTestRepo("handoff-cloud"); + cleanups.push(cloudRepo.cleanup); + await seedCloudRepo(cloudRepo); + + const localRepo = await cloneTestRepo(cloudRepo.path, "handoff-local"); + cleanups.push(localRepo.cleanup); + + const branch = await cloudRepo.git(["rev-parse", "--abbrev-ref", "HEAD"]); + const localHead = await localRepo.git(["rev-parse", "HEAD"]); + const upstreamHead = await localRepo.git(["rev-parse", `origin/${branch}`]); + + return { + cloudRepo, + localRepo, + branch, + localGitState: { + head: localHead, + branch, + upstreamHead, + upstreamRemote: "origin", + upstreamMergeRef: `refs/heads/${branch}`, + }, + }; +} + +async function makeCloudChanges(repo: TestRepo): Promise { + await repo.writeFile("committed.txt", "cloud commit\n"); + await repo.git(["add", "committed.txt"]); + await repo.git(["commit", "-m", "Cloud commit"]); + + await repo.writeFile("tracked.txt", "staged change\n"); + await repo.git(["add", "tracked.txt"]); + await repo.writeFile("unstaged.txt", "unstaged change\n"); + await repo.writeFile("untracked.txt", "untracked\n"); +} + +async function mirrorRestoredWorktree( + cloudRepo: TestRepo, + localRepo: TestRepo, +): Promise { + for (const file of WORKTREE_FILES) { + await localRepo.writeFile(file, await cloudRepo.readFile(file)); + } +} + +describe("HandoffCheckpointTracker", () => { + const cleanups: Array<() => Promise> = []; + + afterEach(async () => { + await Promise.all(cleanups.splice(0).map((cleanup) => cleanup())); + }); + + it("restores head commit and index state for handoff replay", async () => { + const { cloudRepo, localRepo, branch, localGitState } = + await prepareHandoffRepos(cleanups); + await makeCloudChanges(cloudRepo); + + const store = createBundleStore(); + const apiClient = createMockApi(store); + const captureTracker = createTracker(cloudRepo.path, apiClient); + + const checkpoint = await captureTracker.captureForHandoff(localGitState); + + expect(checkpoint).not.toBeNull(); + if (!checkpoint) return; + expect(Object.keys(store.artifacts).length).toBeGreaterThan(0); + + const applyTracker = createTracker(localRepo.path, apiClient); + await applyTracker.applyFromHandoff(checkpoint); + + // The handoff service restores files separately via tree_snapshot. + // Mirror that here so the restored git index can be validated. + await mirrorRestoredWorktree(cloudRepo, localRepo); + + expect(await localRepo.git(["rev-parse", "HEAD"])).toBe(checkpoint.head); + expect(await localRepo.git(["rev-parse", "--abbrev-ref", "HEAD"])).toBe( + branch, + ); + expect(await localRepo.readFile("committed.txt")).toBe("cloud commit\n"); + expect(await localRepo.readFile("tracked.txt")).toBe("staged change\n"); + expect(await localRepo.readFile("unstaged.txt")).toBe("unstaged change\n"); + expect(await localRepo.readFile("untracked.txt")).toBe("untracked\n"); + + const status = await localRepo.git(["status", "--porcelain"]); + expect(status).toContain("M tracked.txt"); + expect(status).toContain(" M unstaged.txt"); + expect(status).toContain("?? untracked.txt"); + }); +}); diff --git a/packages/agent/src/handoff-checkpoint.ts b/packages/agent/src/handoff-checkpoint.ts new file mode 100644 index 000000000..03ea9ad0c --- /dev/null +++ b/packages/agent/src/handoff-checkpoint.ts @@ -0,0 +1,361 @@ +import { mkdir, readFile, rm, writeFile } from "node:fs/promises"; +import { join } from "node:path"; +import { + type GitHandoffBranchDivergence, + type GitHandoffCheckpoint, + GitHandoffTracker, +} from "@posthog/git/handoff"; +import type { PostHogAPIClient } from "./posthog-api"; +import type { GitCheckpoint, HandoffLocalGitState } from "./types"; +import { Logger } from "./utils/logger"; + +export interface HandoffCheckpointTrackerConfig { + repositoryPath: string; + taskId: string; + runId: string; + apiClient?: PostHogAPIClient; + logger?: Logger; +} + +type ArtifactTransfer = T & { + rawBytes: number; + wireBytes: number; +}; + +type UploadedArtifact = ArtifactTransfer<{ storagePath?: string }>; +type DownloadedArtifact = ArtifactTransfer<{ filePath: string }>; + +type ArtifactKey = "pack" | "index"; +type ArtifactSlotMap = Partial< + Record> +>; + +interface UploadArtifactSpec { + key: ArtifactKey; + filePath?: string; + name: string; + contentType: string; +} + +interface DownloadArtifactSpec { + key: ArtifactKey; + storagePath?: string; + filePath: string; + label: string; +} + +type Uploads = ArtifactSlotMap<{ storagePath?: string }>; +type Downloads = ArtifactSlotMap<{ filePath: string }>; + +export class HandoffCheckpointTracker { + private repositoryPath: string; + private taskId: string; + private runId: string; + private apiClient?: PostHogAPIClient; + private logger: Logger; + + constructor(config: HandoffCheckpointTrackerConfig) { + this.repositoryPath = config.repositoryPath; + this.taskId = config.taskId; + this.runId = config.runId; + this.apiClient = config.apiClient; + this.logger = + config.logger || + new Logger({ debug: false, prefix: "[HandoffCheckpointTracker]" }); + } + + async captureForHandoff( + localGitState?: HandoffLocalGitState, + ): Promise { + if (!this.apiClient) { + throw new Error( + "Cannot capture handoff checkpoint: API client not configured", + ); + } + + const gitTracker = this.createGitTracker(); + const capture = await gitTracker.captureForHandoff(localGitState); + + try { + const uploads = await this.uploadArtifacts([ + { + key: "pack", + filePath: capture.headPack?.path, + name: `handoff/${capture.checkpoint.checkpointId}.pack`, + contentType: "application/x-git-packed-objects", + }, + { + key: "index", + filePath: capture.indexFile.path, + name: `handoff/${capture.checkpoint.checkpointId}.index`, + contentType: "application/octet-stream", + }, + ]); + + this.logCaptureMetrics(capture.checkpoint, uploads); + + return { + ...capture.checkpoint, + artifactPath: uploads.pack?.storagePath, + indexArtifactPath: uploads.index?.storagePath, + }; + } finally { + await this.removeIfPresent(capture.headPack?.path); + await this.removeIfPresent(capture.indexFile.path); + } + } + + async applyFromHandoff( + checkpoint: GitCheckpoint, + options?: { + localGitState?: HandoffLocalGitState; + onDivergedBranch?: ( + divergence: GitHandoffBranchDivergence, + ) => Promise; + }, + ): Promise { + if (!this.apiClient) { + throw new Error( + "Cannot apply handoff checkpoint: API client not configured", + ); + } + + const gitTracker = this.createGitTracker(); + const tmpDir = join(this.repositoryPath, ".posthog", "tmp"); + await mkdir(tmpDir, { recursive: true }); + + const packPath = join(tmpDir, `${checkpoint.checkpointId}.pack`); + const indexPath = join(tmpDir, `${checkpoint.checkpointId}.index`); + + try { + const downloads = await this.downloadArtifacts([ + { + key: "pack", + storagePath: checkpoint.artifactPath, + filePath: packPath, + label: "handoff pack", + }, + { + key: "index", + storagePath: checkpoint.indexArtifactPath, + filePath: indexPath, + label: "handoff index", + }, + ]); + + const applyResult = await gitTracker.applyFromHandoff({ + checkpoint: this.toGitCheckpoint(checkpoint), + headPackPath: downloads.pack?.filePath, + indexPath: downloads.index?.filePath, + localGitState: options?.localGitState, + onDivergedBranch: options?.onDivergedBranch, + }); + + this.logApplyMetrics(checkpoint, downloads, applyResult.totalBytes); + } finally { + await this.removeIfPresent(packPath); + await this.removeIfPresent(indexPath); + } + } + + private toGitCheckpoint(checkpoint: GitCheckpoint): GitHandoffCheckpoint { + return { + checkpointId: checkpoint.checkpointId, + commit: checkpoint.commit, + checkpointRef: checkpoint.checkpointRef, + headRef: checkpoint.headRef, + head: checkpoint.head, + branch: checkpoint.branch, + indexTree: checkpoint.indexTree, + worktreeTree: checkpoint.worktreeTree, + timestamp: checkpoint.timestamp, + upstreamRemote: checkpoint.upstreamRemote ?? null, + upstreamMergeRef: checkpoint.upstreamMergeRef ?? null, + remoteUrl: checkpoint.remoteUrl ?? null, + }; + } + + private async uploadArtifactFile( + filePath: string, + name: string, + contentType: string, + ): Promise { + if (!this.apiClient) { + return { rawBytes: 0, wireBytes: 0 }; + } + + const content = await readFile(filePath); + const base64Content = content.toString("base64"); + const artifacts = await this.apiClient.uploadTaskArtifacts( + this.taskId, + this.runId, + [ + { + name, + type: "artifact", + content: base64Content, + content_type: contentType, + }, + ], + ); + + return { + storagePath: artifacts.at(-1)?.storage_path, + rawBytes: content.byteLength, + wireBytes: Buffer.byteLength(base64Content, "utf-8"), + }; + } + + private async uploadArtifacts(specs: UploadArtifactSpec[]): Promise { + const uploads = await Promise.all( + specs.map(async (spec) => { + if (!spec.filePath) { + return [spec.key, undefined] as const; + } + return [ + spec.key, + await this.uploadArtifactFile( + spec.filePath, + spec.name, + spec.contentType, + ), + ] as const; + }), + ); + + return Object.fromEntries(uploads) as Uploads; + } + + private async downloadArtifactToFile( + artifactPath: string, + filePath: string, + label: string, + ): Promise { + if (!this.apiClient) { + throw new Error(`Cannot download ${label}: API client not configured`); + } + + const arrayBuffer = await this.apiClient.downloadArtifact( + this.taskId, + this.runId, + artifactPath, + ); + if (!arrayBuffer) { + throw new Error(`Failed to download ${label}`); + } + + const base64Content = Buffer.from(arrayBuffer).toString("utf-8"); + const binaryContent = Buffer.from(base64Content, "base64"); + await writeFile(filePath, binaryContent); + return { + filePath, + rawBytes: binaryContent.byteLength, + wireBytes: arrayBuffer.byteLength, + }; + } + + private async downloadArtifacts( + specs: DownloadArtifactSpec[], + ): Promise { + const downloads = await Promise.all( + specs.map(async (spec) => { + if (!spec.storagePath) { + return [spec.key, undefined] as const; + } + return [ + spec.key, + await this.downloadArtifactToFile( + spec.storagePath, + spec.filePath, + spec.label, + ), + ] as const; + }), + ); + + return Object.fromEntries(downloads) as Downloads; + } + + private createGitTracker(): GitHandoffTracker { + return new GitHandoffTracker({ + repositoryPath: this.repositoryPath, + logger: this.logger, + }); + } + + private logCaptureMetrics( + checkpoint: GitHandoffCheckpoint, + uploads: Uploads, + ): void { + this.logger.info("Captured handoff checkpoint", { + checkpointId: checkpoint.checkpointId, + branch: checkpoint.branch, + head: checkpoint.head, + artifactPath: uploads.pack?.storagePath, + indexArtifactPath: uploads.index?.storagePath, + ...this.buildMetricPayload(uploads), + }); + } + + private logApplyMetrics( + checkpoint: GitCheckpoint, + downloads: Downloads, + totalBytes: number, + ): void { + this.logger.info("Applied handoff checkpoint", { + checkpointId: checkpoint.checkpointId, + commit: checkpoint.commit, + branch: checkpoint.branch, + head: checkpoint.head, + packBytes: downloads.pack?.rawBytes ?? 0, + packWireBytes: downloads.pack?.wireBytes ?? 0, + indexBytes: downloads.index?.rawBytes ?? 0, + indexWireBytes: downloads.index?.wireBytes ?? 0, + totalBytes, + totalWireBytes: this.sumWireBytes(downloads.pack, downloads.index), + }); + } + + private buildMetricPayload(metrics: ArtifactSlotMap): { + packBytes: number; + packWireBytes: number; + indexBytes: number; + indexWireBytes: number; + totalBytes: number; + totalWireBytes: number; + } { + return { + packBytes: metrics.pack?.rawBytes ?? 0, + packWireBytes: metrics.pack?.wireBytes ?? 0, + indexBytes: metrics.index?.rawBytes ?? 0, + indexWireBytes: metrics.index?.wireBytes ?? 0, + totalBytes: this.sumRawBytes(metrics.pack, metrics.index), + totalWireBytes: this.sumWireBytes(metrics.pack, metrics.index), + }; + } + + private sumRawBytes( + ...artifacts: Array<{ rawBytes: number } | undefined> + ): number { + return artifacts.reduce( + (total, artifact) => total + (artifact?.rawBytes ?? 0), + 0, + ); + } + + private sumWireBytes( + ...artifacts: Array<{ wireBytes: number } | undefined> + ): number { + return artifacts.reduce( + (total, artifact) => total + (artifact?.wireBytes ?? 0), + 0, + ); + } + + private async removeIfPresent(filePath: string | undefined): Promise { + if (!filePath) { + return; + } + await rm(filePath, { force: true }).catch(() => {}); + } +} diff --git a/packages/agent/src/posthog-api.test.ts b/packages/agent/src/posthog-api.test.ts index 4ab3f3bd2..d09893331 100644 --- a/packages/agent/src/posthog-api.test.ts +++ b/packages/agent/src/posthog-api.test.ts @@ -45,4 +45,33 @@ describe("PostHogAPIClient", () => { expect(refreshApiKey).toHaveBeenCalledTimes(1); expect(mockFetch).toHaveBeenCalledTimes(2); }); + + it("returns only the artifacts created by the current upload request", async () => { + const client = new PostHogAPIClient({ + apiUrl: "https://app.posthog.com", + getApiKey: vi.fn().mockResolvedValue("token"), + projectId: 1, + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: vi.fn().mockResolvedValue({ + artifacts: [ + { storage_path: "gs://bucket/existing.tar.gz", name: "existing" }, + { storage_path: "gs://bucket/new-1.pack", name: "new-1" }, + { storage_path: "gs://bucket/new-2.index", name: "new-2" }, + ], + }), + }); + + const artifacts = await client.uploadTaskArtifacts("task-1", "run-1", [ + { name: "new-1", type: "artifact", content: "AAA" }, + { name: "new-2", type: "artifact", content: "BBB" }, + ]); + + expect(artifacts).toEqual([ + { storage_path: "gs://bucket/new-1.pack", name: "new-1" }, + { storage_path: "gs://bucket/new-2.index", name: "new-2" }, + ]); + }); }); diff --git a/packages/agent/src/posthog-api.ts b/packages/agent/src/posthog-api.ts index c9d4b3e4f..f50578e87 100644 --- a/packages/agent/src/posthog-api.ts +++ b/packages/agent/src/posthog-api.ts @@ -206,7 +206,11 @@ export class PostHogAPIClient { }, ); - return response.artifacts ?? []; + const manifest = response.artifacts ?? []; + + // The backend returns the full run artifact manifest after each upload. + // Callers want the artifacts corresponding to this upload request only. + return manifest.slice(-artifacts.length); } async getArtifactPresignedUrl( diff --git a/packages/agent/src/resume.ts b/packages/agent/src/resume.ts index cd7f8cd0f..84dcf5103 100644 --- a/packages/agent/src/resume.ts +++ b/packages/agent/src/resume.ts @@ -19,12 +19,17 @@ import type { ContentBlock } from "@agentclientprotocol/sdk"; import { selectRecentTurns } from "./adapters/claude/session/jsonl-hydration"; import type { PostHogAPIClient } from "./posthog-api"; import { ResumeSaga } from "./sagas/resume-saga"; -import type { DeviceInfo, TreeSnapshotEvent } from "./types"; +import type { + DeviceInfo, + GitCheckpointEvent, + TreeSnapshotEvent, +} from "./types"; import { Logger } from "./utils/logger"; export interface ResumeState { conversation: ConversationTurn[]; latestSnapshot: TreeSnapshotEvent | null; + latestGitCheckpoint: GitCheckpointEvent | null; /** Whether the tree snapshot was successfully applied (files restored) */ snapshotApplied: boolean; interrupted: boolean; @@ -96,6 +101,7 @@ export async function resumeFromLog( return { conversation: result.data.conversation as ConversationTurn[], latestSnapshot: result.data.latestSnapshot, + latestGitCheckpoint: result.data.latestGitCheckpoint, snapshotApplied: result.data.snapshotApplied, interrupted: result.data.interrupted, lastDevice: result.data.lastDevice, diff --git a/packages/agent/src/sagas/apply-snapshot-saga.ts b/packages/agent/src/sagas/apply-snapshot-saga.ts index 8e4c89402..01a942c30 100644 --- a/packages/agent/src/sagas/apply-snapshot-saga.ts +++ b/packages/agent/src/sagas/apply-snapshot-saga.ts @@ -59,6 +59,13 @@ export class ApplySnapshotSaga extends Saga< const base64Content = Buffer.from(arrayBuffer).toString("utf-8"); const binaryContent = Buffer.from(base64Content, "base64"); await writeFile(archivePath, binaryContent); + this.log.info("Tree archive downloaded", { + treeHash: snapshot.treeHash, + snapshotBytes: binaryContent.byteLength, + snapshotWireBytes: arrayBuffer.byteLength, + totalBytes: binaryContent.byteLength, + totalWireBytes: arrayBuffer.byteLength, + }); }, rollback: async () => { if (this.archivePath) { diff --git a/packages/agent/src/sagas/capture-tree-saga.ts b/packages/agent/src/sagas/capture-tree-saga.ts index 1852535a8..851082637 100644 --- a/packages/agent/src/sagas/capture-tree-saga.ts +++ b/packages/agent/src/sagas/capture-tree-saga.ts @@ -113,6 +113,8 @@ export class CaptureTreeSaga extends Saga { execute: async () => { const archiveContent = await readFile(archivePath); const base64Content = archiveContent.toString("base64"); + const snapshotBytes = archiveContent.byteLength; + const snapshotWireBytes = Buffer.byteLength(base64Content, "utf-8"); const artifacts = await apiClient.uploadTaskArtifacts(taskId, runId, [ { @@ -123,12 +125,17 @@ export class CaptureTreeSaga extends Saga { }, ]); - if (artifacts.length > 0 && artifacts[0].storage_path) { + const uploadedArtifact = artifacts[0]; + if (uploadedArtifact?.storage_path) { this.log.info("Tree archive uploaded", { - storagePath: artifacts[0].storage_path, + storagePath: uploadedArtifact.storage_path, treeHash, + snapshotBytes, + snapshotWireBytes, + totalBytes: snapshotBytes, + totalWireBytes: snapshotWireBytes, }); - return artifacts[0].storage_path; + return uploadedArtifact.storage_path; } return undefined; diff --git a/packages/agent/src/sagas/resume-saga.ts b/packages/agent/src/sagas/resume-saga.ts index c363020f2..d8658b19e 100644 --- a/packages/agent/src/sagas/resume-saga.ts +++ b/packages/agent/src/sagas/resume-saga.ts @@ -5,6 +5,7 @@ import type { PostHogAPIClient } from "../posthog-api"; import { TreeTracker } from "../tree-tracker"; import type { DeviceInfo, + GitCheckpointEvent, StoredNotification, TreeSnapshotEvent, } from "../types"; @@ -34,6 +35,7 @@ export interface ResumeInput { export interface ResumeOutput { conversation: ConversationTurn[]; latestSnapshot: TreeSnapshotEvent | null; + latestGitCheckpoint: GitCheckpointEvent | null; snapshotApplied: boolean; interrupted: boolean; lastDevice?: DeviceInfo; @@ -75,6 +77,11 @@ export class ResumeSaga extends Saga { Promise.resolve(this.findLatestTreeSnapshot(entries)), ); + const latestGitCheckpoint = await this.readOnlyStep( + "find_git_checkpoint", + () => Promise.resolve(this.findLatestGitCheckpoint(entries)), + ); + // Step 4: Apply snapshot if present (wrapped in step for consistent logging) // Note: We use a try/catch inside the step because snapshot failure should NOT fail the saga let snapshotApplied = false; @@ -158,6 +165,7 @@ export class ResumeSaga extends Saga { return { conversation, latestSnapshot, + latestGitCheckpoint, snapshotApplied, interrupted: latestSnapshot?.interrupted ?? false, lastDevice, @@ -169,6 +177,7 @@ export class ResumeSaga extends Saga { return { conversation: [], latestSnapshot: null, + latestGitCheckpoint: null, snapshotApplied: false, interrupted: false, logEntryCount: 0, @@ -197,6 +206,29 @@ export class ResumeSaga extends Saga { return null; } + private findLatestGitCheckpoint( + entries: StoredNotification[], + ): GitCheckpointEvent | null { + const sdkPrefixedMethod = `_${POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT}`; + + for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; + const method = entry.notification?.method; + if ( + method === sdkPrefixedMethod || + method === POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT + ) { + const params = entry.notification?.params as + | GitCheckpointEvent + | undefined; + if (params?.checkpointId && params?.checkpointRef) { + return params; + } + } + } + return null; + } + private findLastDeviceInfo( entries: StoredNotification[], ): DeviceInfo | undefined { diff --git a/packages/agent/src/sagas/test-fixtures.ts b/packages/agent/src/sagas/test-fixtures.ts index f569f5d68..1cf56a56d 100644 --- a/packages/agent/src/sagas/test-fixtures.ts +++ b/packages/agent/src/sagas/test-fixtures.ts @@ -67,6 +67,52 @@ export async function createTestRepo(prefix = "test-repo"): Promise { }; } +export async function cloneTestRepo( + sourcePath: string, + prefix = "test-repo-clone", +): Promise { + const clonePath = join( + tmpdir(), + `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2)}`, + ); + await execFileAsync("git", ["clone", sourcePath, clonePath]); + await execFileAsync("git", ["config", "user.email", "test@test.com"], { + cwd: clonePath, + }); + await execFileAsync("git", ["config", "user.name", "Test"], { + cwd: clonePath, + }); + await execFileAsync("git", ["config", "commit.gpgsign", "false"], { + cwd: clonePath, + }); + + const git = async (args: string[]): Promise => { + const { stdout } = await execFileAsync("git", args, { cwd: clonePath }); + return stdout.trim(); + }; + + return { + path: clonePath, + cleanup: () => rm(clonePath, { recursive: true, force: true }), + git, + writeFile: async (relativePath: string, content: string) => { + const fullPath = join(clonePath, relativePath); + const dir = join(fullPath, ".."); + await mkdir(dir, { recursive: true }); + await writeFile(fullPath, content); + }, + readFile: async (relativePath: string) => { + return readFile(join(clonePath, relativePath), "utf-8"); + }, + deleteFile: async (relativePath: string) => { + await rm(join(clonePath, relativePath), { force: true }); + }, + exists: (relativePath: string) => { + return existsSync(join(clonePath, relativePath)); + }, + }; +} + export function createMockLogger(): SagaLogger { return { info: vi.fn(), diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 426d9da29..3258e89cc 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -13,6 +13,7 @@ import { createAcpConnection, type InProcessAcpConnection, } from "../adapters/acp-connection"; +import { HandoffCheckpointTracker } from "../handoff-checkpoint"; import { PostHogAPIClient } from "../posthog-api"; import { formatConversationForResume, @@ -24,6 +25,8 @@ import { TreeTracker } from "../tree-tracker"; import type { AgentMode, DeviceInfo, + GitCheckpointEvent, + HandoffLocalGitState, LogLevel, TaskRun, TreeSnapshotEvent, @@ -37,7 +40,11 @@ import { promptBlocksToText, } from "./cloud-prompt"; import { type JwtPayload, JwtValidationError, validateJwt } from "./jwt"; -import { jsonRpcRequestSchema, validateCommandParams } from "./schemas"; +import { + handoffLocalGitStateSchema, + jsonRpcRequestSchema, + validateCommandParams, +} from "./schemas"; import type { AgentServerConfig } from "./types"; type MessageCallback = (message: unknown) => void; @@ -156,6 +163,7 @@ interface ActiveSession { sseController: SseController | null; deviceInfo: DeviceInfo; logWriter: SessionLogWriter; + pendingHandoffGitState?: HandoffLocalGitState; } export class AgentServer { @@ -574,6 +582,10 @@ export class AgentServer { case POSTHOG_NOTIFICATIONS.CLOSE: case "close": { this.logger.info("Close requested"); + const localGitState = this.extractHandoffLocalGitState(params); + if (localGitState && this.session) { + this.session.pendingHandoffGitState = localGitState; + } await this.cleanupSession(); return { closed: true }; } @@ -756,6 +768,7 @@ export class AgentServer { sseController, deviceInfo, logWriter, + pendingHandoffGitState: undefined, }; this.logger = new Logger({ @@ -1529,6 +1542,12 @@ ${attributionInstructions} this.logger.info("Cleaning up session"); + try { + await this.captureHandoffCheckpoint(); + } catch (error) { + this.logger.error("Failed to capture handoff checkpoint", error); + } + try { await this.captureTreeState(); } catch (error) { @@ -1595,6 +1614,60 @@ ${attributionInstructions} } } + private async captureHandoffCheckpoint(): Promise { + if (!this.session?.treeTracker || !this.session.pendingHandoffGitState) { + return; + } + if (!this.posthogAPI) { + this.logger.warn( + "Skipping handoff checkpoint capture: PostHog API client is not configured", + ); + return; + } + + const tracker = new HandoffCheckpointTracker({ + repositoryPath: this.config.repositoryPath ?? "/tmp/workspace", + taskId: this.session.payload.task_id, + runId: this.session.payload.run_id, + apiClient: this.posthogAPI, + logger: this.logger.child("HandoffCheckpoint"), + }); + + const checkpoint = await tracker.captureForHandoff( + this.session.pendingHandoffGitState, + ); + if (!checkpoint) return; + + const checkpointWithDevice: GitCheckpointEvent = { + ...checkpoint, + device: this.session.deviceInfo, + }; + + const notification = { + jsonrpc: "2.0" as const, + method: POSTHOG_NOTIFICATIONS.GIT_CHECKPOINT, + params: checkpointWithDevice, + }; + + this.broadcastEvent({ + type: "notification", + timestamp: new Date().toISOString(), + notification, + }); + + this.session.logWriter.appendRawLine( + this.session.payload.run_id, + JSON.stringify(notification), + ); + } + + private extractHandoffLocalGitState( + params: Record, + ): HandoffLocalGitState | null { + const result = handoffLocalGitStateSchema.safeParse(params.localGitState); + return result.success ? result.data : null; + } + private broadcastTurnComplete(stopReason: string): void { if (!this.session) return; this.broadcastEvent({ diff --git a/packages/agent/src/server/schemas.ts b/packages/agent/src/server/schemas.ts index 96e9cf022..4f7cfb9c2 100644 --- a/packages/agent/src/server/schemas.ts +++ b/packages/agent/src/server/schemas.ts @@ -5,6 +5,19 @@ const httpHeaderSchema = z.object({ value: z.string(), }); +const nullishString = z + .string() + .nullish() + .transform((value) => value ?? null); + +export const handoffLocalGitStateSchema = z.object({ + head: nullishString, + branch: nullishString, + upstreamHead: nullishString, + upstreamRemote: nullishString, + upstreamMergeRef: nullishString, +}); + const remoteMcpServerSchema = z.object({ type: z.enum(["http", "sse"]), name: z.string().min(1, "MCP server name is required"), @@ -48,13 +61,19 @@ export const userMessageParamsSchema = z.object({ ]), }); +export const closeParamsSchema = z + .object({ + localGitState: handoffLocalGitStateSchema.optional(), + }) + .optional(); + export const commandParamsSchemas = { user_message: userMessageParamsSchema, "posthog/user_message": userMessageParamsSchema, cancel: z.object({}).optional(), "posthog/cancel": z.object({}).optional(), - close: z.object({}).optional(), - "posthog/close": z.object({}).optional(), + close: closeParamsSchema, + "posthog/close": closeParamsSchema, } as const; export type CommandMethod = keyof typeof commandParamsSchemas; diff --git a/packages/agent/src/types.ts b/packages/agent/src/types.ts index 3cc039d6a..09c39e04b 100644 --- a/packages/agent/src/types.ts +++ b/packages/agent/src/types.ts @@ -1,3 +1,8 @@ +import type { + GitHandoffCheckpoint, + HandoffLocalGitState as GitHandoffLocalGitState, +} from "@posthog/git/handoff"; + /** * Stored custom notification following ACP extensibility model. * Custom notifications use underscore-prefixed methods (e.g., `_posthog/phase_start`). @@ -185,3 +190,22 @@ export interface TreeSnapshot { export interface TreeSnapshotEvent extends TreeSnapshot { device?: DeviceInfo; } + +export type HandoffLocalGitState = GitHandoffLocalGitState; + +export interface GitCheckpoint extends GitHandoffCheckpoint { + artifactPath?: string; + indexArtifactPath?: string; +} + +export interface GitCheckpointEvent extends GitCheckpoint { + device?: DeviceInfo; +} + +/** + * Keeps the emitted `@posthog/agent/types` entrypoint as a runtime ESM module. + * + * `export {}` is stripped by tsup in this package, which leaves `dist/types.js` + * empty and breaks downstream type resolution for the exported subpath. + */ +export const AGENT_TYPES_MODULE = true; diff --git a/packages/agent/tsup.config.ts b/packages/agent/tsup.config.ts index eb80ccad2..88ed0ceee 100644 --- a/packages/agent/tsup.config.ts +++ b/packages/agent/tsup.config.ts @@ -74,6 +74,7 @@ export default defineConfig([ "src/index.ts", "src/agent.ts", "src/gateway-models.ts", + "src/handoff-checkpoint.ts", "src/posthog-api.ts", "src/resume.ts", "src/tree-tracker.ts", @@ -85,6 +86,7 @@ export default defineConfig([ "src/adapters/claude/session/jsonl-hydration.ts", "src/adapters/claude/session/models.ts", "src/execution-mode.ts", + "src/server/schemas.ts", "src/server/agent-server.ts", ], format: ["esm"], diff --git a/packages/git/src/handoff.test.ts b/packages/git/src/handoff.test.ts new file mode 100644 index 000000000..99c5c5c02 --- /dev/null +++ b/packages/git/src/handoff.test.ts @@ -0,0 +1,391 @@ +import { execFile } from "node:child_process"; +import { mkdtemp, readFile, rm, writeFile } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import path from "node:path"; +import { promisify } from "node:util"; +import { describe, expect, it, vi } from "vitest"; +import { createGitClient } from "./client"; +import { + type GitHandoffApplyInput, + type GitHandoffCaptureResult, + GitHandoffTracker, + type HandoffLocalGitState, +} from "./handoff"; + +const execFileAsync = promisify(execFile); + +async function setupRepo(): Promise { + const dir = await mkdtemp(path.join(tmpdir(), "posthog-code-handoff-")); + const git = createGitClient(dir); + await git.init(); + await git.addConfig("user.name", "PostHog Code Test"); + await git.addConfig("user.email", "posthog-code-test@example.com"); + await git.addConfig("commit.gpgsign", "false"); + + await writeFile(path.join(dir, "tracked.txt"), "base\n"); + await writeFile(path.join(dir, "unstaged.txt"), "base unstaged\n"); + await git.add(["tracked.txt", "unstaged.txt"]); + await git.commit("initial"); + + return dir; +} + +async function cloneRepo(sourcePath: string): Promise { + const clonePath = await mkdtemp( + path.join(tmpdir(), "posthog-code-handoff-clone-"), + ); + await execFileAsync("git", ["clone", sourcePath, clonePath]); + await execFileAsync("git", ["config", "user.email", "test@test.com"], { + cwd: clonePath, + }); + await execFileAsync("git", ["config", "user.name", "Test"], { + cwd: clonePath, + }); + await execFileAsync("git", ["config", "commit.gpgsign", "false"], { + cwd: clonePath, + }); + return clonePath; +} + +interface RepoHarness { + cloudRepo: string; + localRepo: string; + branch: string; + cloudGit: ReturnType; + localGit: ReturnType; + localGitState: HandoffLocalGitState; +} + +async function withRepos( + fn: (repos: RepoHarness) => Promise, +): Promise { + const cloudRepo = await setupRepo(); + const localRepo = await cloneRepo(cloudRepo); + const cloudGit = createGitClient(cloudRepo); + const localGit = createGitClient(localRepo); + try { + const branch = (await cloudGit.revparse(["--abbrev-ref", "HEAD"])).trim(); + const localHead = (await localGit.revparse(["HEAD"])).trim(); + const upstreamHead = (await localGit.revparse([`origin/${branch}`])).trim(); + + return await fn({ + cloudRepo, + localRepo, + branch, + cloudGit, + localGit, + localGitState: { + head: localHead, + branch, + upstreamHead, + upstreamRemote: "origin", + upstreamMergeRef: `refs/heads/${branch}`, + }, + }); + } finally { + await rm(localRepo, { recursive: true, force: true }); + await rm(cloudRepo, { recursive: true, force: true }); + } +} + +async function makeCloudChanges( + cloudRepo: string, + cloudGit: ReturnType, +) { + await writeFile(path.join(cloudRepo, "committed.txt"), "cloud commit\n"); + await cloudGit.add(["committed.txt"]); + await cloudGit.commit("Cloud commit"); + + await writeFile(path.join(cloudRepo, "tracked.txt"), "staged change\n"); + await cloudGit.add(["tracked.txt"]); + await writeFile(path.join(cloudRepo, "unstaged.txt"), "unstaged change\n"); + await writeFile(path.join(cloudRepo, "untracked.txt"), "untracked\n"); +} + +async function mirrorWorktreeFiles( + fromRepo: string, + toRepo: string, + files: string[], +): Promise { + await Promise.all( + files.map(async (file) => { + await writeFile( + path.join(toRepo, file), + await readFile(path.join(fromRepo, file), "utf-8"), + ); + }), + ); +} + +async function cleanupCapture(capture: GitHandoffCaptureResult): Promise { + if (capture.headPack?.path) { + await rm(capture.headPack.path, { force: true }).catch(() => {}); + } + await rm(capture.indexFile.path, { force: true }).catch(() => {}); +} + +async function captureAndApply( + repos: RepoHarness, + options?: { + captureState?: HandoffLocalGitState; + applyState?: HandoffLocalGitState; + onDivergedBranch?: GitHandoffApplyInput["onDivergedBranch"]; + }, +): Promise { + const captureTracker = new GitHandoffTracker({ + repositoryPath: repos.cloudRepo, + }); + const capture = await captureTracker.captureForHandoff( + options?.captureState ?? repos.localGitState, + ); + + const applyTracker = new GitHandoffTracker({ + repositoryPath: repos.localRepo, + }); + + try { + await applyTracker.applyFromHandoff({ + checkpoint: capture.checkpoint, + headPackPath: capture.headPack?.path, + indexPath: capture.indexFile.path, + localGitState: options?.applyState ?? repos.localGitState, + onDivergedBranch: options?.onDivergedBranch, + }); + } catch (error) { + await cleanupCapture(capture); + throw error; + } + + return capture; +} + +describe("GitHandoffTracker", () => { + it("captures and reapplies head and index state from local files", async () => { + await withRepos(async (repos) => { + await makeCloudChanges(repos.cloudRepo, repos.cloudGit); + const capture = await captureAndApply(repos); + + try { + await mirrorWorktreeFiles(repos.cloudRepo, repos.localRepo, [ + "tracked.txt", + "unstaged.txt", + "untracked.txt", + ]); + + expect((await repos.localGit.revparse(["HEAD"])).trim()).toBe( + capture.checkpoint.head, + ); + expect( + (await repos.localGit.revparse(["--abbrev-ref", "HEAD"])).trim(), + ).toBe(repos.branch); + expect( + await readFile(path.join(repos.localRepo, "committed.txt"), "utf-8"), + ).toBe("cloud commit\n"); + + const status = await repos.localGit.raw(["status", "--porcelain"]); + expect(status).toContain("M tracked.txt"); + expect(status).toContain(" M unstaged.txt"); + expect(status).toContain("?? untracked.txt"); + } finally { + await cleanupCapture(capture); + } + }); + }, 15000); + + it("prompts before resetting a diverged local branch", async () => { + await withRepos(async (repos) => { + await writeFile( + path.join(repos.localRepo, "local-only.txt"), + "local commit\n", + ); + await repos.localGit.add(["local-only.txt"]); + await repos.localGit.commit("Local only"); + const localHead = (await repos.localGit.revparse(["HEAD"])).trim(); + + await writeFile( + path.join(repos.cloudRepo, "cloud-only.txt"), + "cloud commit\n", + ); + await repos.cloudGit.add(["cloud-only.txt"]); + await repos.cloudGit.commit("Cloud only"); + + const captureTracker = new GitHandoffTracker({ + repositoryPath: repos.cloudRepo, + }); + const capture = await captureTracker.captureForHandoff({ + ...repos.localGitState, + head: localHead, + upstreamHead: null, + }); + + const confirm = vi.fn().mockResolvedValue(false); + const applyTracker = new GitHandoffTracker({ + repositoryPath: repos.localRepo, + }); + + try { + await expect( + applyTracker.applyFromHandoff({ + checkpoint: capture.checkpoint, + headPackPath: capture.headPack?.path, + indexPath: capture.indexFile.path, + localGitState: { + ...repos.localGitState, + head: localHead, + upstreamHead: null, + }, + onDivergedBranch: confirm, + }), + ).rejects.toThrow("Handoff aborted"); + + expect(confirm).toHaveBeenCalledWith( + expect.objectContaining({ + branch: repos.branch, + cloudHead: capture.checkpoint.head, + }), + ); + expect( + ( + await repos.localGit.revparse([`refs/heads/${repos.branch}`]) + ).trim(), + ).not.toBe(capture.checkpoint.head); + } finally { + await cleanupCapture(capture); + } + }); + }, 15000); + + it("preserves existing local upstream config", async () => { + await withRepos(async (repos) => { + await repos.localGit.raw([ + "remote", + "set-url", + "origin", + "git@github.com:local/repo.git", + ]); + await repos.localGit.raw([ + "config", + `branch.${repos.branch}.remote`, + "origin", + ]); + await repos.localGit.raw([ + "config", + `branch.${repos.branch}.merge`, + `refs/heads/${repos.branch}`, + ]); + + await repos.cloudGit.addRemote( + "cloud-origin", + "https://example.com/cloud.git", + ); + await repos.cloudGit.raw([ + "config", + `branch.${repos.branch}.remote`, + "cloud-origin", + ]); + await repos.cloudGit.raw([ + "config", + `branch.${repos.branch}.merge`, + `refs/heads/${repos.branch}`, + ]); + + await writeFile( + path.join(repos.cloudRepo, "cloud-only.txt"), + "cloud commit\n", + ); + await repos.cloudGit.add(["cloud-only.txt"]); + await repos.cloudGit.commit("Cloud only"); + + const capture = await captureAndApply(repos, { + captureState: { + ...repos.localGitState, + upstreamHead: null, + }, + }); + + try { + expect( + ( + await repos.localGit.raw([ + "config", + "--get", + `branch.${repos.branch}.remote`, + ]) + ).trim(), + ).toBe("origin"); + expect( + (await repos.localGit.raw(["remote", "get-url", "origin"])).trim(), + ).toBe("git@github.com:local/repo.git"); + } finally { + await cleanupCapture(capture); + } + }); + }, 15000); + + it("adopts cloud upstream when the local branch has none", async () => { + await withRepos(async (repos) => { + await repos.localGit + .raw(["config", "--unset-all", `branch.${repos.branch}.remote`]) + .catch(() => {}); + await repos.localGit + .raw(["config", "--unset-all", `branch.${repos.branch}.merge`]) + .catch(() => {}); + await repos.localGit.removeRemote("origin"); + + await repos.cloudGit.addRemote( + "cloud-origin", + "https://example.com/cloud.git", + ); + await repos.cloudGit.raw([ + "config", + `branch.${repos.branch}.remote`, + "cloud-origin", + ]); + await repos.cloudGit.raw([ + "config", + `branch.${repos.branch}.merge`, + `refs/heads/${repos.branch}`, + ]); + + await writeFile( + path.join(repos.cloudRepo, "cloud-only.txt"), + "cloud commit\n", + ); + await repos.cloudGit.add(["cloud-only.txt"]); + await repos.cloudGit.commit("Cloud only"); + + const capture = await captureAndApply(repos, { + captureState: { + ...repos.localGitState, + upstreamHead: null, + upstreamRemote: null, + upstreamMergeRef: null, + }, + applyState: { + ...repos.localGitState, + upstreamRemote: null, + upstreamMergeRef: null, + }, + }); + + try { + expect( + ( + await repos.localGit.raw([ + "config", + "--get", + `branch.${repos.branch}.remote`, + ]) + ).trim(), + ).toBe("cloud-origin"); + expect( + ( + await repos.localGit.raw(["remote", "get-url", "cloud-origin"]) + ).trim(), + ).toBe("https://example.com/cloud.git"); + } finally { + await cleanupCapture(capture); + } + }); + }, 15000); +}); diff --git a/packages/git/src/handoff.ts b/packages/git/src/handoff.ts new file mode 100644 index 000000000..a01343c31 --- /dev/null +++ b/packages/git/src/handoff.ts @@ -0,0 +1,639 @@ +import { spawn } from "node:child_process"; +import { copyFile, mkdir, readFile, rm, stat } from "node:fs/promises"; +import path from "node:path"; +import type { SagaLogger } from "@posthog/shared"; +import { createGitClient, type GitClient } from "./client"; +import { CaptureCheckpointSaga, deleteCheckpoint } from "./sagas/checkpoint"; + +const HANDOFF_HEAD_REF_PREFIX = "refs/posthog-code-handoff/head/"; +const CHECKPOINT_REF_PREFIX = "refs/posthog-code-checkpoint/"; + +export interface HandoffLocalGitState { + head: string | null; + branch: string | null; + upstreamHead: string | null; + upstreamRemote: string | null; + upstreamMergeRef: string | null; +} + +export interface GitHandoffCheckpoint { + checkpointId: string; + commit: string; + checkpointRef: string; + headRef?: string; + head: string | null; + branch: string | null; + indexTree: string; + worktreeTree: string; + timestamp: string; + upstreamRemote: string | null; + upstreamMergeRef: string | null; + remoteUrl: string | null; +} + +export interface GitHandoffArtifactFile { + path: string; + rawBytes: number; +} + +export interface GitHandoffCaptureResult { + checkpoint: GitHandoffCheckpoint; + headPack?: GitHandoffArtifactFile; + indexFile: GitHandoffArtifactFile; + totalBytes: number; +} + +export interface GitHandoffApplyInput { + checkpoint: GitHandoffCheckpoint; + headPackPath?: string; + indexPath?: string; + localGitState?: HandoffLocalGitState; + onDivergedBranch?: ( + divergence: GitHandoffBranchDivergence, + ) => Promise; +} + +export interface GitHandoffApplyResult { + packBytes: number; + indexBytes: number; + totalBytes: number; +} + +export interface GitHandoffBranchDivergence { + branch: string; + localHead: string; + cloudHead: string; +} + +export interface GitHandoffTrackerConfig { + repositoryPath: string; + logger?: SagaLogger; +} + +interface GitTrackingMetadata { + upstreamRemote: string | null; + upstreamMergeRef: string | null; + remoteUrl: string | null; +} + +type GitBranchRestoreStatus = + | { kind: "missing" } + | { kind: "match" } + | { kind: "fast_forward" } + | { kind: "diverged"; divergence: GitHandoffBranchDivergence }; + +export class GitHandoffTracker { + private repositoryPath: string; + private logger?: SagaLogger; + + constructor(config: GitHandoffTrackerConfig) { + this.repositoryPath = config.repositoryPath; + this.logger = config.logger; + } + + async captureForHandoff( + localGitState?: HandoffLocalGitState, + ): Promise { + const captureSaga = new CaptureCheckpointSaga(this.logger); + const result = await captureSaga.run({ baseDir: this.repositoryPath }); + if (!result.success) { + throw new Error( + `Failed to capture checkpoint at step '${result.failedStep}': ${result.error}`, + ); + } + + const checkpoint = result.data; + const git = createGitClient(this.repositoryPath); + const tempDir = await this.getTempDir(git); + const checkpointRef = `${CHECKPOINT_REF_PREFIX}${checkpoint.checkpointId}`; + const shouldIncludeHead = + !!checkpoint.head && checkpoint.head !== localGitState?.head; + const headRef = shouldIncludeHead + ? `${HANDOFF_HEAD_REF_PREFIX}${checkpoint.checkpointId}` + : undefined; + const packPrefix = path.join(tempDir, checkpoint.checkpointId); + + try { + const [headPack, indexFile, tracking] = await Promise.all([ + shouldIncludeHead && checkpoint.head + ? this.captureHeadPack(packPrefix, checkpoint.head) + : Promise.resolve(undefined), + this.copyIndexFile(git, checkpoint.checkpointId), + getTrackingMetadata(git, checkpoint.branch), + ]); + + return { + checkpoint: { + checkpointId: checkpoint.checkpointId, + commit: checkpoint.commit, + checkpointRef, + headRef, + head: checkpoint.head, + branch: checkpoint.branch, + indexTree: checkpoint.indexTree, + worktreeTree: checkpoint.worktreeTree, + timestamp: checkpoint.timestamp, + upstreamRemote: tracking.upstreamRemote, + upstreamMergeRef: tracking.upstreamMergeRef, + remoteUrl: tracking.remoteUrl, + }, + headPack, + indexFile, + totalBytes: (headPack?.rawBytes ?? 0) + indexFile.rawBytes, + }; + } finally { + await deleteCheckpoint(git, checkpoint.checkpointId).catch(() => {}); + } + } + + async applyFromHandoff( + input: GitHandoffApplyInput, + ): Promise { + const { + checkpoint, + headPackPath, + indexPath, + localGitState, + onDivergedBranch, + } = input; + const git = createGitClient(this.repositoryPath); + + if (headPackPath) { + await this.unpackPackFile(headPackPath); + } + + if (checkpoint.branch && checkpoint.head) { + const branchStatus = await this.resolveBranchRestoreStatus( + git, + checkpoint.branch, + checkpoint.head, + localGitState, + ); + const tracking = this.getPreferredTracking(localGitState, checkpoint); + + if ( + branchStatus.kind === "diverged" && + !(await onDivergedBranch?.(branchStatus.divergence)) + ) { + throw new Error( + `Handoff aborted: local branch '${checkpoint.branch}' has diverged`, + ); + } + + await this.checkoutBranchAtHead(git, checkpoint.branch, checkpoint.head); + + if (this.shouldRestoreTracking(branchStatus, localGitState, tracking)) { + await this.ensureRemoteForTracking(git, tracking); + await this.configureUpstream(git, checkpoint.branch, tracking); + } + } else if (checkpoint.head) { + await git.checkout(checkpoint.head); + } + + if (indexPath) { + await this.restoreIndexFile(git, indexPath); + } + + const packBytes = headPackPath ? await this.getFileSize(headPackPath) : 0; + const indexBytes = indexPath ? await this.getFileSize(indexPath) : 0; + + return { + packBytes, + indexBytes, + totalBytes: packBytes + indexBytes, + }; + } + + private async captureHeadPack( + packPrefix: string, + headCommit: string, + ): Promise { + const hash = await this.runGitWithInput( + ["pack-objects", packPrefix, "--revs"], + `${headCommit}\n`, + ); + const packPath = `${packPrefix}-${hash.trim()}.pack`; + const rawBytes = await this.getFileSize(packPath); + await rm(`${packPath}.idx`, { force: true }).catch(() => {}); + return { path: packPath, rawBytes }; + } + + private async copyIndexFile( + git: GitClient, + checkpointId: string, + ): Promise { + const indexPath = await this.getGitPath(git, "index"); + const tempDir = await this.getTempDir(git); + const copiedIndexPath = path.join(tempDir, `${checkpointId}.index`); + await copyFile(indexPath, copiedIndexPath); + return { + path: copiedIndexPath, + rawBytes: await this.getFileSize(copiedIndexPath), + }; + } + + private async restoreIndexFile( + git: GitClient, + indexPath: string, + ): Promise { + const gitIndexPath = await this.getGitPath(git, "index"); + await copyFile(indexPath, gitIndexPath); + } + + private async unpackPackFile(packPath: string): Promise { + const content = await readFile(packPath); + await this.runGitWithBuffer(["unpack-objects", "-r"], content); + } + + private getPreferredTracking( + localGitState: HandoffLocalGitState | undefined, + checkpoint: GitHandoffCheckpoint, + ): GitTrackingMetadata { + const state = localGitState; + if (state && hasTrackingConfig(state)) { + return { + upstreamRemote: state.upstreamRemote ?? null, + upstreamMergeRef: state.upstreamMergeRef ?? null, + remoteUrl: + state.upstreamRemote && + state.upstreamRemote === checkpoint.upstreamRemote + ? checkpoint.remoteUrl + : null, + }; + } + + return { + upstreamRemote: checkpoint.upstreamRemote, + upstreamMergeRef: checkpoint.upstreamMergeRef, + remoteUrl: checkpoint.remoteUrl, + }; + } + + private shouldRestoreTracking( + branchStatus: GitBranchRestoreStatus, + localGitState: HandoffLocalGitState | undefined, + tracking: GitTrackingMetadata, + ): boolean { + return ( + branchStatus.kind === "missing" || + (!hasTrackingConfig(localGitState) && + (tracking.upstreamRemote !== null || + tracking.upstreamMergeRef !== null)) + ); + } + + private async ensureRemoteForTracking( + git: GitClient, + tracking: GitTrackingMetadata, + ): Promise { + if (!tracking.upstreamRemote || !tracking.remoteUrl) return; + + const remotes = await git.getRemotes(true); + const existing = remotes.find( + (remote) => remote.name === tracking.upstreamRemote, + ); + + if (!existing) { + await git.addRemote(tracking.upstreamRemote, tracking.remoteUrl); + } + } + + private async configureUpstream( + git: GitClient, + branch: string, + tracking: GitTrackingMetadata, + ): Promise { + if (tracking.upstreamRemote) { + await git.raw([ + "config", + `branch.${branch}.remote`, + tracking.upstreamRemote, + ]); + } + + if (tracking.upstreamMergeRef) { + await git.raw([ + "config", + `branch.${branch}.merge`, + tracking.upstreamMergeRef, + ]); + } + } + + private async resolveBranchRestoreStatus( + git: GitClient, + branch: string, + cloudHead: string, + localGitState?: HandoffLocalGitState, + ): Promise { + const branchRef = `refs/heads/${branch}`; + const branchExists = await this.refExists(git, branchRef); + if (!branchExists) { + return { kind: "missing" }; + } + + const currentBranchHead = (await git.revparse([branchRef])).trim(); + const candidateHeads = [ + currentBranchHead, + ...(localGitState?.branch === branch && localGitState.head + ? [localGitState.head] + : []), + ].filter((value, index, array) => array.indexOf(value) === index); + + if (candidateHeads.every((head) => head === cloudHead)) { + return { kind: "match" }; + } + + const nonAncestorHead = await this.findNonAncestorHead( + git, + candidateHeads, + cloudHead, + ); + if (!nonAncestorHead) { + return { kind: "fast_forward" }; + } + + return { + kind: "diverged", + divergence: { + branch, + localHead: nonAncestorHead, + cloudHead, + }, + }; + } + + private async findNonAncestorHead( + _git: GitClient, + heads: string[], + cloudHead: string, + ): Promise { + for (const head of heads) { + if (head === cloudHead) { + continue; + } + if (!(await this.isAncestor(head, cloudHead))) { + return head; + } + } + return null; + } + + private async checkoutBranchAtHead( + git: GitClient, + branch: string, + head: string, + ): Promise { + const currentBranch = await getCurrentBranchName(git); + if (currentBranch === branch) { + await git.reset(["--hard", head]); + return; + } + + const branchRef = `refs/heads/${branch}`; + if (await this.refExists(git, branchRef)) { + await git.branch(["-f", branch, head]); + await git.checkout(branch); + return; + } + + await git.checkout(["-b", branch, head]); + } + + private async refExists(git: GitClient, ref: string): Promise { + try { + await git.revparse(["--verify", ref]); + return true; + } catch { + return false; + } + } + + private async isAncestor( + ancestor: string, + descendant: string, + ): Promise { + const exitCode = await this.runGitProcessAllowingFailure([ + "merge-base", + "--is-ancestor", + ancestor, + descendant, + ]); + return exitCode === 0; + } + + private async getTempDir(git: GitClient): Promise { + const raw = await git.raw(["rev-parse", "--git-common-dir"]); + const commonDir = raw.trim() || ".git"; + const resolved = path.isAbsolute(commonDir) + ? commonDir + : path.resolve(this.repositoryPath, commonDir); + const tempDir = path.join(resolved, "posthog-code-tmp"); + await mkdir(tempDir, { recursive: true }); + return tempDir; + } + + private async getGitPath(git: GitClient, gitPath: string): Promise { + const raw = await git.raw(["rev-parse", "--git-path", gitPath]); + const resolved = raw.trim(); + return path.isAbsolute(resolved) + ? resolved + : path.resolve(this.repositoryPath, resolved); + } + + private async getFileSize(filePath: string): Promise { + return (await stat(filePath)).size; + } + + private async runGitWithInput( + args: string[], + input: string, + ): Promise { + const { stdout } = await this.runGitProcess(args, input); + return stdout; + } + + private async runGitWithBuffer(args: string[], input: Buffer): Promise { + await this.runGitProcess(args, input); + } + + private async runGitProcessAllowingFailure(args: string[]): Promise { + return new Promise((resolve, reject) => { + const child = spawn("git", args, { + cwd: this.repositoryPath, + stdio: ["ignore", "ignore", "pipe"], + }); + + let stderr = ""; + child.stderr.on("data", (chunk: Buffer | string) => { + stderr += chunk.toString(); + }); + child.on("error", reject); + child.on("close", (code) => { + if (code === null) { + reject(new Error(`git ${args.join(" ")} exited unexpectedly`)); + return; + } + if (code > 1) { + reject( + new Error( + stderr || `git ${args.join(" ")} failed with code ${code}`, + ), + ); + return; + } + resolve(code); + }); + }); + } + + private runGitProcess( + args: string[], + input: string | Buffer, + ): Promise<{ stdout: string; stderr: string }> { + return new Promise((resolve, reject) => { + const child = spawn("git", args, { + cwd: this.repositoryPath, + stdio: "pipe", + }); + + let stdout = ""; + let stderr = ""; + + child.stdout.on("data", (chunk: Buffer | string) => { + stdout += chunk.toString(); + }); + child.stderr.on("data", (chunk: Buffer | string) => { + stderr += chunk.toString(); + }); + child.on("error", reject); + child.on("close", (code) => { + if (code === 0) { + resolve({ stdout, stderr }); + return; + } + reject( + new Error(stderr || `git ${args.join(" ")} failed with code ${code}`), + ); + }); + + child.stdin.end(input); + }); + } +} + +export async function readHandoffLocalGitState( + repositoryPath: string, +): Promise { + const git = createGitClient(repositoryPath); + const head = await readCurrentHead(git); + const branch = await getCurrentBranchName(git); + const tracking = await getTrackingMetadata(git, branch); + const upstreamHead = + tracking.upstreamRemote && tracking.upstreamMergeRef + ? await resolveUpstreamHead( + git, + tracking.upstreamRemote, + tracking.upstreamMergeRef, + ) + : null; + + return { + head, + branch, + upstreamHead, + upstreamRemote: tracking.upstreamRemote, + upstreamMergeRef: tracking.upstreamMergeRef, + }; +} + +async function readCurrentHead(git: GitClient): Promise { + try { + return (await git.revparse(["HEAD"])).trim() || null; + } catch { + return null; + } +} + +async function getCurrentBranchName(git: GitClient): Promise { + try { + const raw = await git.revparse(["--abbrev-ref", "HEAD"]); + const branch = raw.trim(); + return branch === "HEAD" ? null : branch; + } catch { + return null; + } +} + +async function getTrackingMetadata( + git: GitClient, + branch: string | null, +): Promise { + if (!branch) { + return { + upstreamRemote: null, + upstreamMergeRef: null, + remoteUrl: null, + }; + } + + const upstreamRemote = await getGitConfigValue( + git, + `branch.${branch}.remote`, + ); + const upstreamMergeRef = await getGitConfigValue( + git, + `branch.${branch}.merge`, + ); + const remoteUrl = upstreamRemote + ? await getRemoteUrl(git, upstreamRemote) + : null; + + return { upstreamRemote, upstreamMergeRef, remoteUrl }; +} + +async function getGitConfigValue( + git: GitClient, + key: string, +): Promise { + try { + const value = await git.raw(["config", "--get", key]); + return value.trim() || null; + } catch { + return null; + } +} + +async function getRemoteUrl( + git: GitClient, + remote: string, +): Promise { + try { + const value = await git.remote(["get-url", remote]); + return typeof value === "string" ? value.trim() || null : null; + } catch { + return null; + } +} + +async function resolveUpstreamHead( + git: GitClient, + upstreamRemote: string, + upstreamMergeRef: string, +): Promise { + const upstreamBranch = upstreamMergeRef.replace("refs/heads/", ""); + try { + return ( + (await git.revparse([`${upstreamRemote}/${upstreamBranch}`])).trim() || + null + ); + } catch { + return null; + } +} + +function hasTrackingConfig( + localGitState: HandoffLocalGitState | undefined, +): boolean { + return !!(localGitState?.upstreamRemote || localGitState?.upstreamMergeRef); +} diff --git a/packages/git/src/sagas/checkpoint.ts b/packages/git/src/sagas/checkpoint.ts index d6d211ad8..b95660c7b 100644 --- a/packages/git/src/sagas/checkpoint.ts +++ b/packages/git/src/sagas/checkpoint.ts @@ -106,6 +106,7 @@ export class CaptureCheckpointSaga extends GitSaga< const rawCommit = await commitGit.raw([ "commit-tree", metaTree.trim(), + ...(headInfo.head ? ["-p", headInfo.head] : []), "-m", message, ]);