Skip to content

Commit

Permalink
ING-827: Added support for pinning queries to an endpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brett Lawson committed Jul 9, 2024
1 parent 1c0a202 commit 2201eb2
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 17 deletions.
31 changes: 31 additions & 0 deletions basehttpcomponent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gocbcorex

import (
"context"
"errors"
"math/rand"
"net/http"
"sync"
Expand Down Expand Up @@ -76,6 +77,36 @@ func (c *baseHttpComponent) GetAllTargets(endpointIdsToIgnore []string) (http.Ro
return state.httpRoundTripper, targets, nil
}

func (c *baseHttpComponent) SelectSpecificEndpoint(endpointId string) (http.RoundTripper, string, string, string, error) {
c.lock.RLock()
state := *c.state
c.lock.RUnlock()

foundEndpoint := ""
for epId, endpoint := range state.endpoints {
if epId == endpointId {
foundEndpoint = endpoint
break
}
}

if foundEndpoint == "" {
return nil, "", "", "", errors.New("invalid endpoint")
}

host, err := getHostFromUri(foundEndpoint)
if err != nil {
return nil, "", "", "", err
}

username, password, err := state.authenticator.GetCredentials(c.serviceType, host)
if err != nil {
return nil, "", "", "", err
}

return state.httpRoundTripper, foundEndpoint, username, password, nil
}

func (c *baseHttpComponent) SelectEndpoint(endpointIdsToIgnore []string) (http.RoundTripper, string, string, string, string, error) {
c.lock.RLock()
state := *c.state
Expand Down
211 changes: 194 additions & 17 deletions querycomponent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gocbcorex

import (
"context"
"encoding/json"
"net/http"
"time"

Expand All @@ -11,8 +12,60 @@ import (
"go.uber.org/zap"
)

type QueryOptions = cbqueryx.QueryOptions
type QueryResultStream = cbqueryx.ResultStream
type QueryOptions struct {
Args []json.RawMessage
AtrCollection string
AutoExecute bool
ClientContextId string
Compression cbqueryx.Compression
Controls bool
Creds []cbqueryx.CredsJson
DurabilityLevel cbqueryx.DurabilityLevel
EncodedPlan string
Encoding cbqueryx.Encoding
Format cbqueryx.Format
KvTimeout time.Duration
MaxParallelism uint32
MemoryQuota uint32
Metrics bool
Namespace string
NumAtrs uint32
PipelineBatch uint32
PipelineCap uint32
Prepared string
PreserveExpiry bool
Pretty bool
Profile cbqueryx.ProfileMode
QueryContext string
ReadOnly bool
ScanCap uint32
ScanConsistency cbqueryx.ScanConsistency
ScanVector json.RawMessage
ScanVectors map[string]json.RawMessage
ScanWait time.Duration
Signature bool
Statement string
Timeout time.Duration
TxData json.RawMessage
TxId string
TxImplicit bool
TxStmtNum uint32
TxTimeout time.Duration
UseCbo bool
UseFts bool

NamedArgs map[string]json.RawMessage
Raw map[string]json.RawMessage

OnBehalfOf *cbhttpx.OnBehalfOfInfo
Endpoint string
}

type QueryResultStream interface {
cbqueryx.ResultStream
Endpoint() string
}

type PreparedStatementCache = cbqueryx.PreparedStatementCache

type QueryComponent struct {
Expand All @@ -37,9 +90,20 @@ type QueryComponentOptions struct {
func OrchestrateQueryEndpoint[RespT any](
ctx context.Context,
w *QueryComponent,
fn func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error),
endpointId string,
fn func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (RespT, error),
) (RespT, error) {
roundTripper, _, endpoint, username, password, err := w.SelectEndpoint(nil)
if endpointId != "" {
roundTripper, endpoint, username, password, err := w.SelectSpecificEndpoint(endpointId)
if err != nil {
var emptyResp RespT
return emptyResp, err
}

return fn(roundTripper, endpointId, endpoint, username, password)
}

roundTripper, endpointId, endpoint, username, password, err := w.SelectEndpoint(nil)
if err != nil {
var emptyResp RespT
return emptyResp, err
Expand All @@ -50,7 +114,7 @@ func OrchestrateQueryEndpoint[RespT any](
return emptyResp, serviceNotAvailableError{Service: ServiceTypeQuery}
}

return fn(roundTripper, endpoint, username, password)
return fn(roundTripper, endpointId, endpoint, username, password)
}

func OrchestrateQueryMgmtCall[OptsT any, RespT any](
Expand All @@ -60,8 +124,8 @@ func OrchestrateQueryMgmtCall[OptsT any, RespT any](
opts OptsT,
) (RespT, error) {
return OrchestrateRetries(ctx, w.retries, func() (RespT, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error) {
return OrchestrateQueryEndpoint(ctx, w, "",
func(roundTripper http.RoundTripper, _, endpoint, username, password string) (RespT, error) {
return execFn(cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand All @@ -81,8 +145,8 @@ func OrchestrateNoResQueryMgmtCall[OptsT any](
opts OptsT,
) error {
return OrchestrateNoResponseRetries(ctx, w.retries, func() error {
_, err := OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (interface{}, error) {
_, err := OrchestrateQueryEndpoint(ctx, w, "",
func(roundTripper http.RoundTripper, _, endpoint, username, password string) (interface{}, error) {
return nil, execFn(cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand Down Expand Up @@ -122,27 +186,88 @@ func (w *QueryComponent) Reconfigure(config *QueryComponentConfig) error {
return nil
}

type queryResultStream struct {
cbqueryx.ResultStream
endpoint string
}

func (s *queryResultStream) Endpoint() string {
return s.endpoint
}

func (w *QueryComponent) Query(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) {
return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) {
return cbqueryx.Query{
return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint,
func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) {
res, err := cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Transport: roundTripper,
Endpoint: endpoint,
Username: username,
Password: password,
}.Query(ctx, opts)
}.Query(ctx, &cbqueryx.QueryOptions{
Args: opts.Args,
AtrCollection: opts.AtrCollection,
AutoExecute: opts.AutoExecute,
ClientContextId: opts.ClientContextId,
Compression: opts.Compression,
Controls: opts.Controls,
Creds: opts.Creds,
DurabilityLevel: opts.DurabilityLevel,
EncodedPlan: opts.EncodedPlan,
Encoding: opts.Encoding,
Format: opts.Format,
KvTimeout: opts.KvTimeout,
MaxParallelism: opts.MaxParallelism,
MemoryQuota: opts.MemoryQuota,
Metrics: opts.Metrics,
Namespace: opts.Namespace,
NumAtrs: opts.NumAtrs,
PipelineBatch: opts.PipelineBatch,
PipelineCap: opts.PipelineCap,
Prepared: opts.Prepared,
PreserveExpiry: opts.PreserveExpiry,
Pretty: opts.Pretty,
Profile: opts.Profile,
QueryContext: opts.QueryContext,
ReadOnly: opts.ReadOnly,
ScanCap: opts.ScanCap,
ScanConsistency: opts.ScanConsistency,
ScanVector: opts.ScanVector,
ScanVectors: opts.ScanVectors,
ScanWait: opts.ScanWait,
Signature: opts.Signature,
Statement: opts.Statement,
Timeout: opts.Timeout,
TxData: opts.TxData,
TxId: opts.TxId,
TxImplicit: opts.TxImplicit,
TxStmtNum: opts.TxStmtNum,
TxTimeout: opts.TxTimeout,
UseCbo: opts.UseCbo,
UseFts: opts.UseFts,
NamedArgs: opts.NamedArgs,
Raw: opts.Raw,
OnBehalfOf: opts.OnBehalfOf,
})
if err != nil {
return nil, err
}

return &queryResultStream{
ResultStream: res,
endpoint: endpointId,
}, nil
})
})
}

func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) {
return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) {
return cbqueryx.PreparedQuery{
return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint,
func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) {
res, err := cbqueryx.PreparedQuery{
Executor: cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand All @@ -152,7 +277,59 @@ func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions)
Password: password,
},
Cache: w.preparedCache,
}.PreparedQuery(ctx, opts)
}.PreparedQuery(ctx, &cbqueryx.QueryOptions{
Args: opts.Args,
AtrCollection: opts.AtrCollection,
AutoExecute: opts.AutoExecute,
ClientContextId: opts.ClientContextId,
Compression: opts.Compression,
Controls: opts.Controls,
Creds: opts.Creds,
DurabilityLevel: opts.DurabilityLevel,
EncodedPlan: opts.EncodedPlan,
Encoding: opts.Encoding,
Format: opts.Format,
KvTimeout: opts.KvTimeout,
MaxParallelism: opts.MaxParallelism,
MemoryQuota: opts.MemoryQuota,
Metrics: opts.Metrics,
Namespace: opts.Namespace,
NumAtrs: opts.NumAtrs,
PipelineBatch: opts.PipelineBatch,
PipelineCap: opts.PipelineCap,
Prepared: opts.Prepared,
PreserveExpiry: opts.PreserveExpiry,
Pretty: opts.Pretty,
Profile: opts.Profile,
QueryContext: opts.QueryContext,
ReadOnly: opts.ReadOnly,
ScanCap: opts.ScanCap,
ScanConsistency: opts.ScanConsistency,
ScanVector: opts.ScanVector,
ScanVectors: opts.ScanVectors,
ScanWait: opts.ScanWait,
Signature: opts.Signature,
Statement: opts.Statement,
Timeout: opts.Timeout,
TxData: opts.TxData,
TxId: opts.TxId,
TxImplicit: opts.TxImplicit,
TxStmtNum: opts.TxStmtNum,
TxTimeout: opts.TxTimeout,
UseCbo: opts.UseCbo,
UseFts: opts.UseFts,
NamedArgs: opts.NamedArgs,
Raw: opts.Raw,
OnBehalfOf: opts.OnBehalfOf,
})
if err != nil {
return nil, err
}

return &queryResultStream{
ResultStream: res,
endpoint: endpointId,
}, nil
})
})
}
Expand Down
37 changes: 37 additions & 0 deletions querycomponent_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,40 @@ func TestQueryMgmtDeferredIndex(t *testing.T) {
return false
}, 30*time.Second, 500*time.Millisecond)
}

// TestQueryNodePinning tests that the same node is used for multiple queries when the endpoint is pinned.
// We do this by performing one query to get an endpoint ID, then do another 10 queries and ensure that the
// same endpoint keeps being used.
func TestQueryNodePinning(t *testing.T) {
testutilsint.SkipIfShortTest(t)

agent := CreateDefaultAgent(t)
t.Cleanup(func() {
err := agent.Close()
require.NoError(t, err)
})

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

res, err := agent.Query(ctx, &gocbcorex.QueryOptions{
Statement: "SELECT 1=1",
Endpoint: "",
})
require.NoError(t, err)

firstEndpoint := res.Endpoint()
require.NotEmpty(t, firstEndpoint)
require.Regexp(t, `^quep-(.*)`, firstEndpoint)

for i := 0; i < 10; i++ {
res, err := agent.Query(ctx, &gocbcorex.QueryOptions{
Statement: "SELECT 1=1",
Endpoint: firstEndpoint,
})
require.NoError(t, err)

endpoint := res.Endpoint()
require.Equal(t, firstEndpoint, endpoint)
}
}

0 comments on commit 2201eb2

Please sign in to comment.