diff --git a/mcp/client.go b/mcp/client.go index dc7bef1c..4860b62e 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -340,7 +340,8 @@ func (cs *ClientSession) ID() string { // Close is idempotent and concurrency safe. func (cs *ClientSession) Close() error { // Note: keepaliveCancel access is safe without a mutex because: - // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 1. keepaliveCancel is only written once during Client.Connect (through startKeepalive), + // which happens before any code that may call Close from another goroutine // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines // 3. The keepalive goroutine calls Close on ping failure, but this is safe since // Close is idempotent and conn.Close() handles concurrent calls correctly diff --git a/mcp/server.go b/mcp/server.go index 97b5d446..2357959c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1031,6 +1031,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp s.opts.Logger.Error("server connect error", "error", err) return nil, err } + + // Start keepalive before returning the session to avoid race conditions with Close. + // This is safe because the spec allows sending pings before initialization (see ServerSession.handle for details). 
+ if s.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + return ss, nil } @@ -1058,9 +1065,6 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar ss.server.opts.Logger.Error("duplicate initialized notification") return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } - if ss.server.opts.KeepAlive > 0 { - ss.startKeepalive(ss.server.opts.KeepAlive) - } if h := ss.server.opts.InitializedHandler; h != nil { h(ctx, serverRequestFor(ss, params)) } @@ -1110,7 +1114,7 @@ type ServerSession struct { server *Server conn *jsonrpc2.Connection mcpConn Connection - keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + keepaliveCancel context.CancelFunc mu sync.Mutex state ServerSessionState @@ -1504,7 +1508,8 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelPara func (ss *ServerSession) Close() error { if ss.keepaliveCancel != nil { // Note: keepaliveCancel access is safe without a mutex because: - // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 1. keepaliveCancel is only written once during Server.Connect (through startKeepalive), + // which happens before any code that may call Close from another goroutine // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines // 3. The keepalive goroutine calls Close on ping failure, but this is safe since // Close is idempotent and conn.Close() handles concurrent calls correctly diff --git a/mcp/server_test.go b/mcp/server_test.go index 85c3990d..1312e1d9 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -508,54 +508,6 @@ func TestServerAddResourceTemplate(t *testing.T) { } } -// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, -// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. 
-func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { - // Set KeepAlive to a long duration to ensure the keepalive - // goroutine stays alive for the duration of the test without actually sending - // ping requests, since we don't have a real client connection established. - server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second}) - ss := &ServerSession{server: server} - - // 1. Initialize the session. - _, err := ss.initialize(context.Background(), &InitializeParams{}) - if err != nil { - t.Fatalf("ServerSession initialize failed: %v", err) - } - - // 2. Call 'initialized' for the first time. This should start the keepalive mechanism. - _, err = ss.initialized(context.Background(), &InitializedParams{}) - if err != nil { - t.Fatalf("First initialized call failed: %v", err) - } - if ss.keepaliveCancel == nil { - t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized") - } - - // Save the cancel function and use defer to ensure resources are cleaned up. - firstCancel := ss.keepaliveCancel - defer firstCancel() - - // 3. Manually set the field to nil. - // Do this to facilitate the test's core assertion. The goal is to verify that - // 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil, - // we can easily check after the next call if a new keepalive goroutine was started. - ss.keepaliveCancel = nil - - // 4. Call 'initialized' for the second time. This should return an error. - _, err = ss.initialized(context.Background(), &InitializedParams{}) - if err == nil { - t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil") - } - - // 5. Re-check the field to ensure it remains nil. - // Since 'initialized' correctly returned an error and did not call - // 'startKeepalive', the field should remain unchanged. 
- if ss.keepaliveCancel != nil { - t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") - } -} - // panicks reports whether f() panics. func panics(f func()) (b bool) { defer func() { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a275d3cd..36002775 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -363,6 +363,34 @@ func TestStreamableServerShutdown(t *testing.T) { } } +// TestStreamableStatelessKeepaliveRace verifies that there is no data race between +// ServerSession.startKeepalive and ServerSession.Close in stateless servers. +func TestStreamableStatelessKeepaliveRace(t *testing.T) { + ctx := context.Background() + server := NewServer(testImpl, &ServerOptions{KeepAlive: time.Hour}) + AddTool(server, &Tool{Name: "greet"}, sayHi) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + for range 50 { + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + }, nil) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + _, _ = cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "world"}, + }) + _ = cs.Close() + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network