From d1dd1898f8449bac3955cfd903f4758f2df4eeb5 Mon Sep 17 00:00:00 2001 From: Brett Lawson <> Date: Fri, 5 Jul 2024 12:38:34 -0700 Subject: [PATCH] ING-827: Added support for pinning queries to an endpoint. --- agentmanager_int_test.go | 9 ++++- basehttpcomponent.go | 32 ++++++++++++++++ querycomponent.go | 78 +++++++++++++++++++++++++++++--------- querycomponent_int_test.go | 47 ++++++++++++++++++++++- 4 files changed, 145 insertions(+), 21 deletions(-) diff --git a/agentmanager_int_test.go b/agentmanager_int_test.go index 1ff2ed98..107b0421 100644 --- a/agentmanager_int_test.go +++ b/agentmanager_int_test.go @@ -7,6 +7,7 @@ import ( "github.com/couchbase/gocbcorex" "github.com/couchbase/gocbcorex/cbmgmtx" + "github.com/couchbase/gocbcorex/cbqueryx" "github.com/stretchr/testify/assert" @@ -66,7 +67,9 @@ func TestOnDemandAgentManagerClose(t *testing.T) { require.NoError(t, err) _, err = agent.Query(context.Background(), &gocbcorex.QueryOptions{ - Statement: "SELECT 1=1", + QueryOptions: cbqueryx.QueryOptions{ + Statement: "SELECT 1=1", + }, }) require.NoError(t, err) @@ -101,7 +104,9 @@ func TestBucketsTrackingAgentManagerClose(t *testing.T) { require.NoError(t, err) _, err = agent.Query(context.Background(), &gocbcorex.QueryOptions{ - Statement: "SELECT 1=1", + QueryOptions: cbqueryx.QueryOptions{ + Statement: "SELECT 1=1", + }, }) require.NoError(t, err) diff --git a/basehttpcomponent.go b/basehttpcomponent.go index ccee4850..ca2f96da 100644 --- a/basehttpcomponent.go +++ b/basehttpcomponent.go @@ -2,6 +2,7 @@ package gocbcorex import ( "context" + "errors" "math/rand" "net/http" "sync" @@ -76,6 +77,37 @@ func (c *baseHttpComponent) GetAllTargets(endpointIdsToIgnore []string) (http.Ro return state.httpRoundTripper, targets, nil } +func (c *baseHttpComponent) SelectSpecificEndpoint(endpointId string) (http.RoundTripper, string, string, string, error) { + c.lock.RLock() + state := *c.state + c.lock.RUnlock() + + foundEndpoint := "" + for _, endpoint := range state.endpoints { + if endpoint == endpointId { + foundEndpoint = endpoint + break + } + } + + if foundEndpoint == "" { + return nil, "", "", "", errors.New("invalid endpoint") + } + + host, err := getHostFromUri(foundEndpoint) + if err != nil { + return nil, "", "", "", err + } + + username, password, err := state.authenticator.GetCredentials(c.serviceType, host) + if err != nil { + return nil, "", "", "", err + } + + return state.httpRoundTripper, foundEndpoint, username, password, nil +} + + func (c *baseHttpComponent) SelectEndpoint(endpointIdsToIgnore []string) (http.RoundTripper, string, string, string, string, error) { c.lock.RLock() state := *c.state diff --git a/querycomponent.go b/querycomponent.go index 535c9eab..89b849fa 100644 --- a/querycomponent.go +++ b/querycomponent.go @@ -11,8 +11,16 @@ import ( "go.uber.org/zap" ) -type QueryOptions = cbqueryx.QueryOptions -type QueryResultStream = cbqueryx.ResultStream +type QueryOptions struct { + cbqueryx.QueryOptions + Endpoint string +} + +type QueryResultStream interface { + cbqueryx.ResultStream + Endpoint() string +} + type PreparedStatementCache = cbqueryx.PreparedStatementCache type QueryComponent struct { @@ -37,9 +45,20 @@ type QueryComponentOptions struct { func OrchestrateQueryEndpoint[RespT any]( ctx context.Context, w *QueryComponent, - fn func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error), + endpointId string, + fn func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (RespT, error), ) (RespT, error) { - roundTripper, _, endpoint, username, password, err := w.SelectEndpoint(nil) + if endpointId != "" { + roundTripper, endpoint, username, password, err := w.SelectSpecificEndpoint(endpointId) + if err != nil { + var emptyResp RespT + return emptyResp, err + } + + return fn(roundTripper, endpointId, endpoint, username, password) + } + + roundTripper, endpointId, endpoint, username, password, err := w.SelectEndpoint(nil) if err != nil { var emptyResp RespT return emptyResp, err @@ -50,7 +69,7 @@ func OrchestrateQueryEndpoint[RespT any]( return emptyResp, serviceNotAvailableError{Service: ServiceTypeQuery} } - return fn(roundTripper, endpoint, username, password) + return fn(roundTripper, endpointId, endpoint, username, password) } func OrchestrateQueryMgmtCall[OptsT any, RespT any]( @@ -60,8 +79,8 @@ func OrchestrateQueryMgmtCall[OptsT any, RespT any]( opts OptsT, ) (RespT, error) { return OrchestrateRetries(ctx, w.retries, func() (RespT, error) { - return OrchestrateQueryEndpoint(ctx, w, - func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error) { + return OrchestrateQueryEndpoint(ctx, w, "", + func(roundTripper http.RoundTripper, _, endpoint, username, password string) (RespT, error) { return execFn(cbqueryx.Query{ Logger: w.logger, UserAgent: w.userAgent, @@ -81,8 +100,8 @@ func OrchestrateNoResQueryMgmtCall[OptsT any]( opts OptsT, ) error { return OrchestrateNoResponseRetries(ctx, w.retries, func() error { - _, err := OrchestrateQueryEndpoint(ctx, w, - func(roundTripper http.RoundTripper, endpoint, username, password string) (interface{}, error) { + _, err := OrchestrateQueryEndpoint(ctx, w, "", + func(roundTripper http.RoundTripper, _, endpoint, username, password string) (interface{}, error) { return nil, execFn(cbqueryx.Query{ Logger: w.logger, UserAgent: w.userAgent, @@ -122,27 +141,44 @@ func (w *QueryComponent) Reconfigure(config *QueryComponentConfig) error { return nil } +type queryResultStream struct { + cbqueryx.ResultStream + endpoint string +} + +func (s *queryResultStream) Endpoint() string { + return s.endpoint +} + func (w *QueryComponent) Query(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) { return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) { - return OrchestrateQueryEndpoint(ctx, w, - func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) { - return cbqueryx.Query{ + return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint, + func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) { + res, err := cbqueryx.Query{ Logger: w.logger, UserAgent: w.userAgent, Transport: roundTripper, Endpoint: endpoint, Username: username, Password: password, - }.Query(ctx, opts) + }.Query(ctx, &opts.QueryOptions) + if err != nil { + return nil, err + } + + return &queryResultStream{ + ResultStream: res, + endpoint: endpointId, + }, nil }) }) } func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) { return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) { - return OrchestrateQueryEndpoint(ctx, w, - func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) { - return cbqueryx.PreparedQuery{ + return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint, + func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) { + res, err := cbqueryx.PreparedQuery{ Executor: cbqueryx.Query{ Logger: w.logger, UserAgent: w.userAgent, @@ -152,7 +188,15 @@ func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions) Password: password, }, Cache: w.preparedCache, - }.PreparedQuery(ctx, opts) + }.PreparedQuery(ctx, &opts.QueryOptions) + if err != nil { + return nil, err + } + + return &queryResultStream{ + ResultStream: res, + endpoint: endpointId, + }, nil }) }) } diff --git a/querycomponent_int_test.go b/querycomponent_int_test.go index 77469f69..ea2ffee0 100644 --- a/querycomponent_int_test.go +++ b/querycomponent_int_test.go @@ -58,8 +58,10 @@ func (nqh *n1qlTestHelper) testN1QLBasic(t *testing.T) { defer cancel() rows, err := nqh.QueryFn(ctx, &gocbcorex.QueryOptions{ - ClientContextId: "12345", - Statement: fmt.Sprintf("SELECT i,testName FROM %s WHERE testName=\"%s\"", testutilsint.TestOpts.BucketName, nqh.TestName), + QueryOptions: cbqueryx.QueryOptions{ + ClientContextId: "12345", + Statement: fmt.Sprintf("SELECT i,testName FROM %s WHERE testName=\"%s\"", testutilsint.TestOpts.BucketName, nqh.TestName), + }, }) if err != nil { nqh.T.Logf("Received error from query: %v", err) @@ -514,3 +516,44 @@ func TestQueryMgmtDeferredIndex(t *testing.T) { return false }, 30*time.Second, 500*time.Millisecond) } + +// TestQueryNodePinning tests that the same node is used for multiple queries when the endpoint is pinned. +// We do this by performing one query to get an endpoint ID, then do another 10 queries and ensure that the +// same endpoint keeps being used. +func TestQueryNodePinning(t *testing.T) { + testutilsint.SkipIfShortTest(t) + + agent := CreateDefaultAgent(t) + t.Cleanup(func() { + err := agent.Close() + require.NoError(t, err) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + res, err := agent.Query(ctx, &gocbcorex.QueryOptions{ + QueryOptions: cbqueryx.QueryOptions{ + Statement: "SELECT 1=1", + }, + Endpoint: "", + }) + require.NoError(t, err) + + firstEndpoint := res.Endpoint() + require.NotEmpty(t, firstEndpoint) + require.Regexp(t, `^quep-(.*)`, firstEndpoint) + + for i := 0; i < 10; i++ { + res, err := agent.Query(ctx, &gocbcorex.QueryOptions{ + QueryOptions: cbqueryx.QueryOptions{ + Statement: "SELECT 1=1", + }, + Endpoint: firstEndpoint, + }) + require.NoError(t, err) + + endpoint := res.Endpoint() + require.Equal(t, firstEndpoint, endpoint) + } +}