Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ING-827: Added support for pinning queries to an endpoint. #283

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions basehttpcomponent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gocbcorex

import (
"context"
"errors"
"math/rand"
"net/http"
"sync"
Expand Down Expand Up @@ -76,6 +77,36 @@ func (c *baseHttpComponent) GetAllTargets(endpointIdsToIgnore []string) (http.Ro
return state.httpRoundTripper, targets, nil
}

func (c *baseHttpComponent) SelectSpecificEndpoint(endpointId string) (http.RoundTripper, string, string, string, error) {
c.lock.RLock()
state := *c.state
c.lock.RUnlock()

foundEndpoint := ""
for epId, endpoint := range state.endpoints {
if epId == endpointId {
foundEndpoint = endpoint
break
}
}

if foundEndpoint == "" {
return nil, "", "", "", errors.New("invalid endpoint")
}

host, err := getHostFromUri(foundEndpoint)
if err != nil {
return nil, "", "", "", err
}

username, password, err := state.authenticator.GetCredentials(c.serviceType, host)
if err != nil {
return nil, "", "", "", err
}

return state.httpRoundTripper, foundEndpoint, username, password, nil
}

func (c *baseHttpComponent) SelectEndpoint(endpointIdsToIgnore []string) (http.RoundTripper, string, string, string, string, error) {
c.lock.RLock()
state := *c.state
Expand Down
211 changes: 194 additions & 17 deletions querycomponent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gocbcorex

import (
"context"
"encoding/json"
"net/http"
"time"

Expand All @@ -11,8 +12,60 @@ import (
"go.uber.org/zap"
)

type QueryOptions = cbqueryx.QueryOptions
type QueryResultStream = cbqueryx.ResultStream
type QueryOptions struct {
Args []json.RawMessage
AtrCollection string
AutoExecute bool
ClientContextId string
Compression cbqueryx.Compression
Controls bool
Creds []cbqueryx.CredsJson
DurabilityLevel cbqueryx.DurabilityLevel
EncodedPlan string
Encoding cbqueryx.Encoding
Format cbqueryx.Format
KvTimeout time.Duration
MaxParallelism uint32
MemoryQuota uint32
Metrics bool
Namespace string
NumAtrs uint32
PipelineBatch uint32
PipelineCap uint32
Prepared string
PreserveExpiry bool
Pretty bool
Profile cbqueryx.ProfileMode
QueryContext string
ReadOnly bool
ScanCap uint32
ScanConsistency cbqueryx.ScanConsistency
ScanVector json.RawMessage
ScanVectors map[string]json.RawMessage
ScanWait time.Duration
Signature bool
Statement string
Timeout time.Duration
TxData json.RawMessage
TxId string
TxImplicit bool
TxStmtNum uint32
TxTimeout time.Duration
UseCbo bool
UseFts bool

NamedArgs map[string]json.RawMessage
Raw map[string]json.RawMessage

OnBehalfOf *cbhttpx.OnBehalfOfInfo
Endpoint string
}

type QueryResultStream interface {
cbqueryx.ResultStream
Endpoint() string
}

type PreparedStatementCache = cbqueryx.PreparedStatementCache

type QueryComponent struct {
Expand All @@ -37,9 +90,20 @@ type QueryComponentOptions struct {
func OrchestrateQueryEndpoint[RespT any](
ctx context.Context,
w *QueryComponent,
fn func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error),
endpointId string,
fn func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (RespT, error),
) (RespT, error) {
roundTripper, _, endpoint, username, password, err := w.SelectEndpoint(nil)
if endpointId != "" {
roundTripper, endpoint, username, password, err := w.SelectSpecificEndpoint(endpointId)
if err != nil {
var emptyResp RespT
return emptyResp, err
}

return fn(roundTripper, endpointId, endpoint, username, password)
}

roundTripper, endpointId, endpoint, username, password, err := w.SelectEndpoint(nil)
if err != nil {
var emptyResp RespT
return emptyResp, err
Expand All @@ -50,7 +114,7 @@ func OrchestrateQueryEndpoint[RespT any](
return emptyResp, serviceNotAvailableError{Service: ServiceTypeQuery}
}

return fn(roundTripper, endpoint, username, password)
return fn(roundTripper, endpointId, endpoint, username, password)
}

func OrchestrateQueryMgmtCall[OptsT any, RespT any](
Expand All @@ -60,8 +124,8 @@ func OrchestrateQueryMgmtCall[OptsT any, RespT any](
opts OptsT,
) (RespT, error) {
return OrchestrateRetries(ctx, w.retries, func() (RespT, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (RespT, error) {
return OrchestrateQueryEndpoint(ctx, w, "",
func(roundTripper http.RoundTripper, _, endpoint, username, password string) (RespT, error) {
return execFn(cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand All @@ -81,8 +145,8 @@ func OrchestrateNoResQueryMgmtCall[OptsT any](
opts OptsT,
) error {
return OrchestrateNoResponseRetries(ctx, w.retries, func() error {
_, err := OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (interface{}, error) {
_, err := OrchestrateQueryEndpoint(ctx, w, "",
func(roundTripper http.RoundTripper, _, endpoint, username, password string) (interface{}, error) {
return nil, execFn(cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand Down Expand Up @@ -122,27 +186,88 @@ func (w *QueryComponent) Reconfigure(config *QueryComponentConfig) error {
return nil
}

type queryResultStream struct {
cbqueryx.ResultStream
endpoint string
}

func (s *queryResultStream) Endpoint() string {
return s.endpoint
}

func (w *QueryComponent) Query(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) {
return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) {
return cbqueryx.Query{
return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint,
func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) {
res, err := cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Transport: roundTripper,
Endpoint: endpoint,
Username: username,
Password: password,
}.Query(ctx, opts)
}.Query(ctx, &cbqueryx.QueryOptions{
Args: opts.Args,
AtrCollection: opts.AtrCollection,
AutoExecute: opts.AutoExecute,
ClientContextId: opts.ClientContextId,
Compression: opts.Compression,
Controls: opts.Controls,
Creds: opts.Creds,
DurabilityLevel: opts.DurabilityLevel,
EncodedPlan: opts.EncodedPlan,
Encoding: opts.Encoding,
Format: opts.Format,
KvTimeout: opts.KvTimeout,
MaxParallelism: opts.MaxParallelism,
MemoryQuota: opts.MemoryQuota,
Metrics: opts.Metrics,
Namespace: opts.Namespace,
NumAtrs: opts.NumAtrs,
PipelineBatch: opts.PipelineBatch,
PipelineCap: opts.PipelineCap,
Prepared: opts.Prepared,
PreserveExpiry: opts.PreserveExpiry,
Pretty: opts.Pretty,
Profile: opts.Profile,
QueryContext: opts.QueryContext,
ReadOnly: opts.ReadOnly,
ScanCap: opts.ScanCap,
ScanConsistency: opts.ScanConsistency,
ScanVector: opts.ScanVector,
ScanVectors: opts.ScanVectors,
ScanWait: opts.ScanWait,
Signature: opts.Signature,
Statement: opts.Statement,
Timeout: opts.Timeout,
TxData: opts.TxData,
TxId: opts.TxId,
TxImplicit: opts.TxImplicit,
TxStmtNum: opts.TxStmtNum,
TxTimeout: opts.TxTimeout,
UseCbo: opts.UseCbo,
UseFts: opts.UseFts,
NamedArgs: opts.NamedArgs,
Raw: opts.Raw,
OnBehalfOf: opts.OnBehalfOf,
})
if err != nil {
return nil, err
}

return &queryResultStream{
ResultStream: res,
endpoint: endpointId,
}, nil
})
})
}

func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions) (QueryResultStream, error) {
return OrchestrateRetries(ctx, w.retries, func() (QueryResultStream, error) {
return OrchestrateQueryEndpoint(ctx, w,
func(roundTripper http.RoundTripper, endpoint, username, password string) (QueryResultStream, error) {
return cbqueryx.PreparedQuery{
return OrchestrateQueryEndpoint(ctx, w, opts.Endpoint,
func(roundTripper http.RoundTripper, endpointId, endpoint, username, password string) (QueryResultStream, error) {
res, err := cbqueryx.PreparedQuery{
Executor: cbqueryx.Query{
Logger: w.logger,
UserAgent: w.userAgent,
Expand All @@ -152,7 +277,59 @@ func (w *QueryComponent) PreparedQuery(ctx context.Context, opts *QueryOptions)
Password: password,
},
Cache: w.preparedCache,
}.PreparedQuery(ctx, opts)
}.PreparedQuery(ctx, &cbqueryx.QueryOptions{
Args: opts.Args,
AtrCollection: opts.AtrCollection,
AutoExecute: opts.AutoExecute,
ClientContextId: opts.ClientContextId,
Compression: opts.Compression,
Controls: opts.Controls,
Creds: opts.Creds,
DurabilityLevel: opts.DurabilityLevel,
EncodedPlan: opts.EncodedPlan,
Encoding: opts.Encoding,
Format: opts.Format,
KvTimeout: opts.KvTimeout,
MaxParallelism: opts.MaxParallelism,
MemoryQuota: opts.MemoryQuota,
Metrics: opts.Metrics,
Namespace: opts.Namespace,
NumAtrs: opts.NumAtrs,
PipelineBatch: opts.PipelineBatch,
PipelineCap: opts.PipelineCap,
Prepared: opts.Prepared,
PreserveExpiry: opts.PreserveExpiry,
Pretty: opts.Pretty,
Profile: opts.Profile,
QueryContext: opts.QueryContext,
ReadOnly: opts.ReadOnly,
ScanCap: opts.ScanCap,
ScanConsistency: opts.ScanConsistency,
ScanVector: opts.ScanVector,
ScanVectors: opts.ScanVectors,
ScanWait: opts.ScanWait,
Signature: opts.Signature,
Statement: opts.Statement,
Timeout: opts.Timeout,
TxData: opts.TxData,
TxId: opts.TxId,
TxImplicit: opts.TxImplicit,
TxStmtNum: opts.TxStmtNum,
TxTimeout: opts.TxTimeout,
UseCbo: opts.UseCbo,
UseFts: opts.UseFts,
NamedArgs: opts.NamedArgs,
Raw: opts.Raw,
OnBehalfOf: opts.OnBehalfOf,
})
if err != nil {
return nil, err
}

return &queryResultStream{
ResultStream: res,
endpoint: endpointId,
}, nil
})
})
}
Expand Down
37 changes: 37 additions & 0 deletions querycomponent_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,40 @@ func TestQueryMgmtDeferredIndex(t *testing.T) {
return false
}, 30*time.Second, 500*time.Millisecond)
}

// TestQueryNodePinning tests that the same node is used for multiple queries when the endpoint is pinned.
// We do this by performing one query to get an endpoint ID, then do another 10 queries and ensure that the
// same endpoint keeps being used.
func TestQueryNodePinning(t *testing.T) {
testutilsint.SkipIfShortTest(t)

agent := CreateDefaultAgent(t)
t.Cleanup(func() {
err := agent.Close()
require.NoError(t, err)
})

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

res, err := agent.Query(ctx, &gocbcorex.QueryOptions{
Statement: "SELECT 1=1",
Endpoint: "",
})
require.NoError(t, err)

firstEndpoint := res.Endpoint()
require.NotEmpty(t, firstEndpoint)
require.Regexp(t, `^quep-(.*)`, firstEndpoint)

for i := 0; i < 10; i++ {
res, err := agent.Query(ctx, &gocbcorex.QueryOptions{
Statement: "SELECT 1=1",
Endpoint: firstEndpoint,
})
require.NoError(t, err)

endpoint := res.Endpoint()
require.Equal(t, firstEndpoint, endpoint)
}
}
Loading