diff --git a/BUILD.bazel b/BUILD.bazel index 2447fc7..4cf9bd4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -35,8 +35,10 @@ go_library( "//src/api/proto/cloudpb:cloudapi_pl_go_proto", "//src/api/proto/vizierpb:vizier_pl_go_proto", "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//codes", "@org_golang_google_grpc//credentials", "@org_golang_google_grpc//metadata", + "@org_golang_google_grpc//status", ], ) diff --git a/results.go b/results.go index eb0e642..9d3a451 100644 --- a/results.go +++ b/results.go @@ -20,10 +20,16 @@ package pxapi import ( "context" + "errors" + "fmt" "io" + "strings" "sync" "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "px.dev/pxapi/errdefs" "px.dev/pxapi/types" "px.dev/pxapi/utils" @@ -59,6 +65,10 @@ type ScriptResults struct { wg sync.WaitGroup stats *ResultsStats + + v *VizierClient + queryID string + origCtx context.Context } func newScriptResults() *ScriptResults { @@ -133,6 +143,37 @@ func (s *ScriptResults) handleGRPCMsg(ctx context.Context, resp *vizierpb.Execut return errdefs.ErrInternalUnImplementedType } +func isTransientGRPCError(err error) bool { + s, ok := status.FromError(err) + if !ok { + return false + } + if s.Code() == codes.Internal && strings.Contains(s.Message(), "RST_STREAM") { + return true + } + return false +} + +func (s *ScriptResults) reconnect() error { + if s.queryID == "" { + return errors.New("cannot reconnect to query that hasn't returned a QueryID yet") + } + req := &vizierpb.ExecuteScriptRequest{ + ClusterID: s.v.vizierID, + QueryID: s.queryID, + EncryptionOptions: s.v.encOpts, + } + ctx, cancel := context.WithCancel(s.origCtx) + res, err := s.v.vzClient.ExecuteScript(s.v.cloud.cloudCtxWithMD(ctx), req) + if err != nil { + cancel() + return err + } + s.cancel = cancel + s.c = res + return nil +} + func (s *ScriptResults) run() error { ctx := s.c.Context() for { @@ -143,11 +184,23 @@ func (s *ScriptResults) run() error { // Stream has terminated. return nil } + if isTransientGRPCError(err) { + origErr := err + err = s.reconnect() + if err != nil { + return fmt.Errorf("streaming failed: %w, error occurred while reconnecting: %v", origErr, err) + } + ctx = s.c.Context() + continue + } return err } if resp == nil { return nil } + if s.queryID == "" { + s.queryID = resp.QueryID + } if err := s.handleGRPCMsg(ctx, resp); err != nil { return err } diff --git a/vizier.go b/vizier.go index 1a2a459..7ff434a 100644 --- a/vizier.go +++ b/vizier.go @@ -42,6 +42,7 @@ func (v *VizierClient) ExecuteScript(ctx context.Context, pxl string, mux TableM QueryStr: pxl, EncryptionOptions: v.encOpts, } + origCtx := ctx ctx, cancel := context.WithCancel(ctx) res, err := v.vzClient.ExecuteScript(v.cloud.cloudCtxWithMD(ctx), req) if err != nil { @@ -54,6 +55,8 @@ func (v *VizierClient) ExecuteScript(ctx context.Context, pxl string, mux TableM sr.cancel = cancel sr.tm = mux sr.decOpts = v.decOpts + sr.v = v + sr.origCtx = origCtx return sr, nil }