diff --git a/pkg/raft/node.go b/pkg/raft/node.go index 329a4b2c9f..b22ce92bed 100644 --- a/pkg/raft/node.go +++ b/pkg/raft/node.go @@ -257,6 +257,18 @@ func (n *Node) NodeID() string { return n.config.NodeID } +// LeaderID returns the server ID of the current cluster leader. +// Returns an empty string if the receiver is nil, raft is uninitialized, or no +// leader has been elected yet. The value may be momentarily stale between raft +// leadership changes; callers that need a strong guarantee should cross-check +// with HasQuorum. +func (n *Node) LeaderID() string { + if n == nil || n.raft == nil { + return "" + } + return n.leaderID() +} + func (n *Node) leaderID() string { _, id := n.raft.LeaderWithID() return string(id) diff --git a/pkg/rpc/server/http.go b/pkg/rpc/server/http.go index 6cffdb2739..a721ce0b82 100644 --- a/pkg/rpc/server/http.go +++ b/pkg/rpc/server/http.go @@ -139,11 +139,16 @@ func RegisterCustomHTTPEndpoints(mux *http.ServeMux, s store.Store, pm p2p.P2PRP http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } + leaderID := raftNode.LeaderID() + isLeader := raftNode.IsLeader() + if leaderID != "" { + isLeader = leaderID == raftNode.NodeID() + } rsp := struct { IsLeader bool `json:"is_leader"` NodeID string `json:"node_id"` }{ - IsLeader: raftNode.IsLeader(), + IsLeader: isLeader, NodeID: raftNode.NodeID(), } w.Header().Set("Content-Type", "application/json") diff --git a/pkg/rpc/server/http_test.go b/pkg/rpc/server/http_test.go index 4fa2e92e17..23e0e180f3 100644 --- a/pkg/rpc/server/http_test.go +++ b/pkg/rpc/server/http_test.go @@ -1,6 +1,7 @@ package server import ( + "encoding/json" "io" "net/http" "net/http/httptest" @@ -45,6 +46,103 @@ func TestRegisterCustomHTTPEndpoints(t *testing.T) { mockStore.AssertExpectations(t) } +type testRaftNodeSource struct { + isLeader bool + leaderID string + nodeID string +} + +func (t testRaftNodeSource) IsLeader() bool { + return t.isLeader +} + +func (t testRaftNodeSource) LeaderID() string { + return t.leaderID +} + +func (t testRaftNodeSource) NodeID() string { + return t.nodeID +} + +func TestRegisterCustomHTTPEndpoints_RaftNodeStatus(t *testing.T) { + type bodyShape struct { + IsLeader bool `json:"is_leader"` + NodeID string `json:"node_id"` + } + + cases := []struct { + name string + node testRaftNodeSource + method string + wantStatus int + wantIsLeader bool + wantNodeID string + skipBodyDecode bool + }{ + { + // leaderID == nodeID: handler derives is_leader=true from LeaderID(), + // regardless of the IsLeader() field on testRaftNodeSource. + name: "leader matches — is_leader true", + node: testRaftNodeSource{leaderID: "node-a", nodeID: "node-a"}, + method: http.MethodGet, + wantStatus: http.StatusOK, + wantIsLeader: true, + wantNodeID: "node-a", + }, + { + // leaderID != nodeID: handler derives is_leader=false. + name: "leader differs — is_leader false", + node: testRaftNodeSource{leaderID: "node-b", nodeID: "node-a"}, + method: http.MethodGet, + wantStatus: http.StatusOK, + wantIsLeader: false, + wantNodeID: "node-a", + }, + { + // empty leaderID: fallback — is_leader=false (no elected leader known). + name: "empty leaderID fallback — is_leader false", + node: testRaftNodeSource{leaderID: "", nodeID: "node-a"}, + method: http.MethodGet, + wantStatus: http.StatusOK, + wantIsLeader: false, + wantNodeID: "node-a", + }, + { + name: "non-GET method — 405", + node: testRaftNodeSource{}, + method: http.MethodPost, + wantStatus: http.StatusMethodNotAllowed, + skipBodyDecode: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mux := http.NewServeMux() + RegisterCustomHTTPEndpoints(mux, nil, nil, config.DefaultConfig(), nil, zerolog.Nop(), tc.node) + + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + + req, err := http.NewRequest(tc.method, ts.URL+"/raft/node", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) //nolint:gosec // test-only request to httptest server + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + + require.Equal(t, tc.wantStatus, resp.StatusCode) + if tc.skipBodyDecode { + return + } + + var body bodyShape + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, tc.wantIsLeader, body.IsLeader) + assert.Equal(t, tc.wantNodeID, body.NodeID) + }) + } +} + func TestHealthReady_aggregatorBlockDelay(t *testing.T) { logger := zerolog.Nop() diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index ce747ae37d..419f8b6631 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -368,6 +368,7 @@ func (p *P2PServer) GetNetInfo( type RaftNodeSource interface { IsLeader() bool + LeaderID() string NodeID() string }