diff --git a/packages/agent/src/server/agent-server.test.ts b/packages/agent/src/server/agent-server.test.ts index 2b69f524c..3ba8943d7 100644 --- a/packages/agent/src/server/agent-server.test.ts +++ b/packages/agent/src/server/agent-server.test.ts @@ -252,6 +252,30 @@ describe("AgentServer HTTP Mode", () => { const body = await response.json(); expect(body.error).toBe("No active session for this run"); }); + + it("accepts structured user_message content", async () => { + await createServer().start(); + const token = createToken({ run_id: "different-run-id" }); + + const response = await fetch(`http://localhost:${port}/command`, { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "user_message", + params: { + content: [{ type: "text", text: "test" }], + }, + }), + }); + + expect(response.status).toBe(400); + const body = await response.json(); + expect(body.error).toBe("No active session for this run"); + }); }); describe("404 handling", () => { diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 8f8fa25fe..2d14f84e0 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -1,3 +1,4 @@ +import type { ContentBlock } from "@agentclientprotocol/sdk"; import { ClientSideConnection, ndJsonStream, @@ -30,6 +31,11 @@ import type { import { AsyncMutex } from "../utils/async-mutex"; import { getLlmGatewayUrl } from "../utils/gateway"; import { Logger } from "../utils/logger"; +import { + deserializeCloudPrompt, + normalizeCloudPromptContent, + promptBlocksToText, +} from "./cloud-prompt"; import { type JwtPayload, JwtValidationError, validateJwt } from "./jwt"; import { jsonRpcRequestSchema, validateCommandParams } from "./schemas"; import type { AgentServerConfig } from "./types"; @@ -487,17 +493,20 @@ export class AgentServer { switch (method) { case POSTHOG_NOTIFICATIONS.USER_MESSAGE: case "user_message": { - const content = params.content as string; + const prompt = normalizeCloudPromptContent( + params.content as string | ContentBlock[], + ); + const promptPreview = promptBlocksToText(prompt); this.logger.info( - `Processing user message (detectedPrUrl=${this.detectedPrUrl ?? "none"}): ${content.substring(0, 100)}...`, + `Processing user message (detectedPrUrl=${this.detectedPrUrl ?? "none"}): ${promptPreview.substring(0, 100)}...`, ); this.session.logWriter.resetTurnMessages(this.session.payload.run_id); const result = await this.session.clientConnection.prompt({ sessionId: this.session.acpSessionId, - prompt: [{ type: "text", text: content }], + prompt, ...(this.detectedPrUrl && { _meta: { prContext: @@ -837,24 +846,33 @@ export class AgentServer { const initialPromptOverride = taskRun ? this.getInitialPromptOverride(taskRun) : null; - const initialPrompt = initialPromptOverride ?? task.description; + const pendingUserPrompt = this.getPendingUserPrompt(taskRun); + let initialPrompt: ContentBlock[] = []; + if (pendingUserPrompt?.length) { + initialPrompt = pendingUserPrompt; + } else if (initialPromptOverride) { + initialPrompt = [{ type: "text", text: initialPromptOverride }]; + } else if (task.description) { + initialPrompt = [{ type: "text", text: task.description }]; + } - if (!initialPrompt) { + if (initialPrompt.length === 0) { this.logger.warn("Task has no description, skipping initial message"); return; } this.logger.info("Sending initial task message", { taskId: payload.task_id, - descriptionLength: initialPrompt.length, + descriptionLength: promptBlocksToText(initialPrompt).length, usedInitialPromptOverride: !!initialPromptOverride, + usedPendingUserMessage: !!pendingUserPrompt?.length, }); this.session.logWriter.resetTurnMessages(payload.run_id); const result = await this.session.clientConnection.prompt({ sessionId: this.session.acpSessionId, - prompt: [{ type: "text", text: initialPrompt }], + prompt: initialPrompt, }); this.logger.info("Initial task message completed", { @@ -886,38 +904,49 @@ export class AgentServer { this.resumeState.conversation, ); - // Read the pending user message from TaskRun state (set by the workflow + // Read the pending user prompt from TaskRun state (set by the workflow // when the user sends a follow-up message that triggers a resume). - const pendingUserMessage = this.getPendingUserMessage(taskRun); + const pendingUserPrompt = this.getPendingUserPrompt(taskRun); const sandboxContext = this.resumeState.snapshotApplied ? `The workspace environment (all files, packages, and code changes) has been fully restored from where you left off.` : `The workspace files from the previous session were not restored (the file snapshot may have expired), so you are starting with a fresh environment. Your conversation history is fully preserved below.`; - let resumePrompt: string; - if (pendingUserMessage) { - // Include the pending message as the user's new question so the agent - // responds to it directly instead of the generic resume context. - resumePrompt = - `You are resuming a previous conversation. ${sandboxContext}\n\n` + - `Here is the conversation history from the previous session:\n\n` + - `${conversationSummary}\n\n` + - `The user has sent a new message:\n\n` + - `${pendingUserMessage}\n\n` + - `Respond to the user's new message above. You have full context from the previous session.`; + let resumePromptBlocks: ContentBlock[]; + if (pendingUserPrompt?.length) { + resumePromptBlocks = [ + { + type: "text", + text: + `You are resuming a previous conversation. ${sandboxContext}\n\n` + + `Here is the conversation history from the previous session:\n\n` + + `${conversationSummary}\n\n` + + `The user has sent a new message:\n\n`, + }, + ...pendingUserPrompt, + { + type: "text", + text: "\n\nRespond to the user's new message above. You have full context from the previous session.", + }, + ]; } else { - resumePrompt = - `You are resuming a previous conversation. ${sandboxContext}\n\n` + - `Here is the conversation history from the previous session:\n\n` + - `${conversationSummary}\n\n` + - `Continue from where you left off. The user is waiting for your response.`; + resumePromptBlocks = [ + { + type: "text", + text: + `You are resuming a previous conversation. ${sandboxContext}\n\n` + + `Here is the conversation history from the previous session:\n\n` + + `${conversationSummary}\n\n` + + `Continue from where you left off. The user is waiting for your response.`, + }, + ]; } this.logger.info("Sending resume message", { taskId: payload.task_id, conversationTurns: this.resumeState.conversation.length, - promptLength: resumePrompt.length, - hasPendingUserMessage: !!pendingUserMessage, + promptLength: promptBlocksToText(resumePromptBlocks).length, + hasPendingUserMessage: !!pendingUserPrompt?.length, snapshotApplied: this.resumeState.snapshotApplied, }); @@ -928,7 +957,7 @@ export class AgentServer { const result = await this.session.clientConnection.prompt({ sessionId: this.session.acpSessionId, - prompt: [{ type: "text", text: resumePrompt }], + prompt: resumePromptBlocks, }); this.logger.info("Resume message completed", { @@ -1013,7 +1042,7 @@ export class AgentServer { return trimmed.length > 0 ? trimmed : null; } - private getPendingUserMessage(taskRun: TaskRun | null): string | null { + private getPendingUserPrompt(taskRun: TaskRun | null): ContentBlock[] | null { if (!taskRun) return null; const state = taskRun.state as Record | undefined; const message = state?.pending_user_message; @@ -1021,8 +1050,8 @@ export class AgentServer { return null; } - const trimmed = message.trim(); - return trimmed.length > 0 ? trimmed : null; + const prompt = deserializeCloudPrompt(message); + return prompt.length > 0 ? prompt : null; } private getResumeRunId(taskRun: TaskRun | null): string | null { diff --git a/packages/agent/src/server/cloud-prompt.ts b/packages/agent/src/server/cloud-prompt.ts new file mode 100644 index 000000000..c85370dbc --- /dev/null +++ b/packages/agent/src/server/cloud-prompt.ts @@ -0,0 +1,13 @@ +import type { ContentBlock } from "@agentclientprotocol/sdk"; +import { deserializeCloudPrompt, promptBlocksToText } from "@posthog/shared"; + +export { deserializeCloudPrompt, promptBlocksToText }; + +export function normalizeCloudPromptContent( + content: string | ContentBlock[], +): ContentBlock[] { + if (typeof content === "string") { + return deserializeCloudPrompt(content); + } + return content; +} diff --git a/packages/agent/src/server/question-relay.test.ts b/packages/agent/src/server/question-relay.test.ts index bdac3891f..5f73abfd3 100644 --- a/packages/agent/src/server/question-relay.test.ts +++ b/packages/agent/src/server/question-relay.test.ts @@ -371,6 +371,53 @@ describe("Question relay", () => { }); describe("sendInitialTaskMessage prompt source", () => { + it("uses pending user prompt blocks when present", async () => { + vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ + id: "test-task-id", + title: "t", + description: "original task description", + } as unknown as Task); + vi.spyOn(server.posthogAPI, "getTaskRun").mockResolvedValue({ + id: "test-run-id", + task: "test-task-id", + state: { + pending_user_message: + '__twig_cloud_prompt_v1__:{"blocks":[{"type":"text","text":"read this attachment"},{"type":"resource","resource":{"uri":"attachment://test.txt","text":"hello from file","mimeType":"text/plain"}}]}', + }, + } as unknown as TaskRun); + + const promptSpy = vi.fn().mockResolvedValue({ stopReason: "max_tokens" }); + server.session = { + payload: TEST_PAYLOAD, + acpSessionId: "acp-session", + clientConnection: { prompt: promptSpy }, + logWriter: { + flushAll: vi.fn().mockResolvedValue(undefined), + getFullAgentResponse: vi.fn().mockReturnValue(null), + resetTurnMessages: vi.fn(), + flush: vi.fn().mockResolvedValue(undefined), + isRegistered: vi.fn().mockReturnValue(true), + }, + }; + + await server.sendInitialTaskMessage(TEST_PAYLOAD); + + expect(promptSpy).toHaveBeenCalledWith({ + sessionId: "acp-session", + prompt: [ + { type: "text", text: "read this attachment" }, + { + type: "resource", + resource: { + uri: "attachment://test.txt", + text: "hello from file", + mimeType: "text/plain", + }, + }, + ], + }); + }); + it("uses run state initial_prompt_override when present", async () => { vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ id: "test-task-id", diff --git a/packages/agent/src/server/schemas.test.ts b/packages/agent/src/server/schemas.test.ts index 87b977aed..30efddc9f 100644 --- a/packages/agent/src/server/schemas.test.ts +++ b/packages/agent/src/server/schemas.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from "vitest"; -import { mcpServersSchema } from "./schemas"; +import { mcpServersSchema, validateCommandParams } from "./schemas"; describe("mcpServersSchema", () => { it("accepts a valid HTTP server", () => { @@ -115,3 +115,21 @@ describe("mcpServersSchema", () => { expect(result.success).toBe(false); }); }); + +describe("validateCommandParams", () => { + it("accepts structured user_message content arrays", () => { + const result = validateCommandParams("user_message", { + content: [{ type: "text", text: "hello" }], + }); + + expect(result.success).toBe(true); + }); + + it("rejects empty content array", () => { + const result = validateCommandParams("user_message", { + content: [], + }); + + expect(result.success).toBe(false); + }); +}); diff --git a/packages/agent/src/server/schemas.ts b/packages/agent/src/server/schemas.ts index 7eb4348a3..2f2528759 100644 --- a/packages/agent/src/server/schemas.ts +++ b/packages/agent/src/server/schemas.ts @@ -42,7 +42,10 @@ export const jsonRpcRequestSchema = z.object({ export type JsonRpcRequest = z.infer; export const userMessageParamsSchema = z.object({ - content: z.string().min(1, "Content is required"), + content: z.union([ + z.string().min(1, "Content is required"), + z.array(z.record(z.string(), z.unknown())).min(1, "Content is required"), + ]), }); export const commandParamsSchemas = {