// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package agent

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"slices"
	"strconv"
	"strings"

	"github.com/golang/snappy"
	"github.com/gorilla/websocket"
	"github.com/hashicorp/go-msgpack/v2/codec"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/helper/pointer"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/drivers"
)

const (
	// allocNotFoundErr is the error message used by the handlers in this file
	// when the requested allocation cannot be found.
	allocNotFoundErr    = "allocation not found"
	// resourceNotFoundErr is the error message used when the request path
	// does not match any known allocation sub-resource.
	resourceNotFoundErr = "resource not found"
)

// AllocsRequest serves the allocation list. Only GET is accepted. The
// optional "resources" and "task_states" query parameters toggle the matching
// fields of the stub field selection sent with the Alloc.List RPC. The result
// is always a non-nil slice of allocation stubs.
func (s *HTTPServer) AllocsRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	if req.Method != http.MethodGet {
		return nil, CodedError(405, ErrInvalidMethod)
	}

	var listReq structs.AllocListRequest
	if done := s.parse(resp, req, &listReq.Region, &listReq.QueryOptions); done {
		return nil, nil
	}

	// Each selector is nil when its query parameter was not supplied.
	withResources, err := parseBool(req, "resources")
	if err != nil {
		return nil, err
	}
	withTaskStates, err := parseBool(req, "task_states")
	if err != nil {
		return nil, err
	}

	// Only send an explicit field selection when the caller asked for one.
	if withResources != nil || withTaskStates != nil {
		fields := structs.NewAllocStubFields()
		if withResources != nil {
			fields.Resources = *withResources
		}
		if withTaskStates != nil {
			fields.TaskStates = *withTaskStates
		}
		listReq.Fields = fields
	}

	var listResp structs.AllocListResponse
	if err := s.agent.RPC("Alloc.List", &listReq, &listResp); err != nil {
		return nil, err
	}
	setMeta(resp, &listResp.QueryMeta)

	// Return an empty slice rather than nil so the reply is never a nil list.
	stubs := listResp.Allocations
	if stubs == nil {
		stubs = make([]*structs.AllocListStub, 0)
	}
	for i := range stubs {
		stubs[i].SetEventDisplayMessages()
	}
	return stubs, nil
}

// AllocSpecificRequest routes requests under /v1/allocation/. A bare
// allocation ID is served by allocGet; "<id>/checks", "<id>/stop" and
// "<id>/services" are routed to their dedicated handlers. Any other shape of
// path yields a 404.
func (s *HTTPServer) AllocSpecificRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Everything after the prefix is "<alloc id>" optionally followed by
	// "/<action>".
	parts := strings.Split(strings.TrimPrefix(req.URL.Path, "/v1/allocation/"), "/")

	switch len(parts) {
	case 1:
		return s.allocGet(parts[0], resp, req)
	case 2:
		// an action was supplied; dispatched below
	default:
		return nil, CodedError(404, resourceNotFoundErr)
	}

	allocID, action := parts[0], parts[1]
	switch action {
	case "checks":
		return s.allocChecks(allocID, resp, req)
	case "stop":
		return s.allocStop(allocID, resp, req)
	case "services":
		return s.allocServiceRegistrations(resp, req, allocID)
	default:
		return nil, CodedError(404, resourceNotFoundErr)
	}
}

// allocGet fetches a single allocation via the Alloc.GetAlloc RPC. Only GET
// is accepted. It returns a 404 coded error when the RPC reply carries no
// allocation. The returned allocation is a private copy with any job payload
// snappy-decoded, event display messages populated and allocated resources
// canonicalized.
func (s *HTTPServer) allocGet(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	if req.Method != http.MethodGet {
		return nil, CodedError(405, ErrInvalidMethod)
	}

	args := structs.AllocSpecificRequest{
		AllocID: allocID,
	}
	if s.parse(resp, req, &args.Region, &args.QueryOptions) {
		return nil, nil
	}

	var out structs.SingleAllocResponse
	if err := s.agent.RPC("Alloc.GetAlloc", &args, &out); err != nil {
		return nil, err
	}

	setMeta(resp, &out.QueryMeta)
	if out.Alloc == nil {
		return nil, CodedError(404, "alloc not found")
	}

	// Take a single copy up front so that every mutation below (payload
	// decoding, display messages, canonicalization) is applied to a private
	// object and never to the allocation handed back by the RPC layer.
	alloc := out.Alloc.Copy()

	// Decode the payload if there is any
	if alloc.Job != nil && len(alloc.Job.Payload) != 0 {
		decoded, err := snappy.Decode(nil, alloc.Job.Payload)
		if err != nil {
			return nil, err
		}
		alloc.Job.Payload = decoded
	}
	alloc.SetEventDisplayMessages()

	// Handle 0.12 ports upgrade path
	alloc.AllocatedResources.Canonicalize()

	return alloc, nil
}

// allocStop stops an allocation via the Alloc.Stop RPC. Only POST and PUT are
// accepted. The optional "no_shutdown_delay" and "reschedule" query
// parameters default to false when absent. An unknown-allocation RPC error is
// translated into a 404; on success the response index is set and the RPC
// reply is returned.
func (s *HTTPServer) allocStop(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	if req.Method != http.MethodPost && req.Method != http.MethodPut {
		return nil, CodedError(405, ErrInvalidMethod)
	}

	noShutdownDelay, err := parseBool(req, "no_shutdown_delay")
	if err != nil {
		return nil, err
	}
	if noShutdownDelay == nil {
		noShutdownDelay = pointer.Of(false)
	}

	reschedule, err := parseBool(req, "reschedule")
	if err != nil {
		return nil, err
	}
	if reschedule == nil {
		reschedule = pointer.Of(false)
	}

	// Keep the request as a plain value so the RPC receives a single-level
	// pointer, matching every other handler in this file.
	sr := structs.AllocStopRequest{
		AllocID:         allocID,
		NoShutdownDelay: *noShutdownDelay,
		Reschedule:      *reschedule,
	}
	s.parseWriteRequest(req, &sr.WriteRequest)

	var out structs.AllocStopResponse
	if rpcErr := s.agent.RPC("Alloc.Stop", &sr, &out); rpcErr != nil {
		if structs.IsErrUnknownAllocation(rpcErr) {
			return nil, CodedError(404, allocNotFoundErr)
		}
		return nil, rpcErr
	}

	setIndex(resp, out.Index)
	return &out, nil
}

// allocServiceRegistrations returns the service registrations reported for
// the given allocation identifier. It is callable via the
// /v1/allocation/:alloc_id/services HTTP API and uses the
// structs.AllocServiceRegistrationsRPCMethod RPC method. Only GET is
// accepted, and a reply without a services list is reported as a 404.
func (s *HTTPServer) allocServiceRegistrations(
	resp http.ResponseWriter, req *http.Request, allocID string) (interface{}, error) {

	if req.Method != http.MethodGet {
		return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod)
	}

	// Parsing fills in the region and query options from the HTTP request.
	rpcArgs := structs.AllocServiceRegistrationsRequest{AllocID: allocID}
	if done := s.parse(resp, req, &rpcArgs.Region, &rpcArgs.QueryOptions); done {
		return nil, nil
	}

	var rpcReply structs.AllocServiceRegistrationsResponse
	err := s.agent.RPC(structs.AllocServiceRegistrationsRPCMethod, &rpcArgs, &rpcReply)
	if err != nil {
		return nil, err
	}
	setMeta(resp, &rpcReply.QueryMeta)

	services := rpcReply.Services
	if services == nil {
		return nil, CodedError(http.StatusNotFound, allocNotFoundErr)
	}
	return services, nil
}

// ClientAllocRequest routes requests under /v1/client/allocation/. The path
// must have exactly the form "<alloc id>/<action>"; unknown actions and any
// other path shape yield a 404.
func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	suffix := strings.TrimPrefix(req.URL.Path, "/v1/client/allocation/")

	// Exactly one separator is allowed: the alloc id, then the action.
	allocID, action, found := strings.Cut(suffix, "/")
	if !found || strings.Contains(action, "/") {
		return nil, CodedError(404, resourceNotFoundErr)
	}

	switch action {
	case "checks":
		return s.allocChecks(allocID, resp, req)
	case "stats":
		return s.allocStats(allocID, resp, req)
	case "exec":
		return s.allocExec(allocID, resp, req)
	case "snapshot":
		// Snapshots are served by the local client directly, so one must be
		// running on this agent.
		if s.agent.Client() == nil {
			return nil, clientNotRunning
		}
		return s.allocSnapshot(allocID, resp, req)
	case "restart":
		return s.allocRestart(allocID, resp, req)
	case "gc":
		return s.allocGC(allocID, resp, req)
	case "signal":
		return s.allocSignal(allocID, resp, req)
	case "pause":
		return s.allocPause(allocID, resp, req)
	default:
		return nil, CodedError(404, resourceNotFoundErr)
	}
}

// ClientGCRequest issues the GarbageCollectAll RPC for the node identified by
// the request, choosing between the local client, the client's RPC connection
// and the local server. A "no node connection" error is reported as a 404.
func (s *HTTPServer) ClientGCRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Build the request and get the requested Node ID
	var gcReq structs.NodeSpecificRequest
	s.parse(resp, req, &gcReq.QueryOptions.Region, &gcReq.QueryOptions)
	parseNode(req, &gcReq.NodeID)

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForNode(gcReq.NodeID)

	// Make the RPC
	var (
		gcResp structs.GenericResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollectAll", &gcReq, &gcResp)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollectAll", &gcReq, &gcResp)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollectAll", &gcReq, &gcResp)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && structs.IsErrNoNodeConn(rpcErr) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return nil, rpcErr
}

// allocRestart issues the Restart RPC for an allocation. The optional JSON
// body may carry TaskName and AllTasks; an empty body is accepted. Errors for
// a missing node connection or an unknown allocation are reported as 404s.
func (s *HTTPServer) allocRestart(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Build the request and parse the ACL token
	restartReq := structs.AllocRestartRequest{AllocID: allocID}
	s.parse(resp, req, &restartReq.QueryOptions.Region, &restartReq.QueryOptions)

	// The body is decoded into a narrow struct, not the RPC request itself,
	// so that it cannot override the AllocID taken from the URL.
	var body struct {
		TaskName string
		AllTasks bool
	}
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil && err != io.EOF {
		return nil, err
	}
	// Both fields are still zero-valued on the freshly built request, so a
	// plain assignment is enough.
	restartReq.TaskName = body.TaskName
	restartReq.AllTasks = body.AllTasks

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		reply  structs.GenericResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.Restart", &restartReq, &reply)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.Restart", &restartReq, &reply)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.Restart", &restartReq, &reply)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return reply, rpcErr
}

// allocGC issues the GarbageCollect RPC for a single allocation. Errors for a
// missing node connection or an unknown allocation are reported as 404s. No
// response body is returned.
func (s *HTTPServer) allocGC(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Build the request and parse the ACL token
	gcReq := structs.AllocSpecificRequest{AllocID: allocID}
	s.parse(resp, req, &gcReq.QueryOptions.Region, &gcReq.QueryOptions)

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		gcResp structs.GenericResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollect", &gcReq, &gcResp)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollect", &gcReq, &gcResp)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollect", &gcReq, &gcResp)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return nil, rpcErr
}

// allocSignal issues the Signal RPC for an allocation. Only POST and PUT are
// accepted. The request body is decoded into the RPC request; a decoding
// failure is reported as a 400. Errors for a missing node connection or an
// unknown allocation are reported as 404s.
func (s *HTTPServer) allocSignal(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	switch req.Method {
	case http.MethodPost, http.MethodPut:
		// supported
	default:
		return nil, CodedError(405, ErrInvalidMethod)
	}

	// Build the request and parse the ACL token
	var signalReq structs.AllocSignalRequest
	if err := decodeBody(req, &signalReq); err != nil {
		return nil, CodedError(400, fmt.Sprintf("Failed to decode body: %v", err))
	}
	s.parse(resp, req, &signalReq.QueryOptions.Region, &signalReq.QueryOptions)
	// Set last so the allocation ID from the URL wins over anything decoded
	// from the body.
	signalReq.AllocID = allocID

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		reply  structs.GenericResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.Signal", &signalReq, &reply)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.Signal", &signalReq, &reply)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.Signal", &signalReq, &reply)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return reply, rpcErr
}

// allocPause dispatches pause requests by HTTP method: GET reads the pause
// state, POST and PUT set it, and anything else is rejected with a 405.
func (s *HTTPServer) allocPause(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	method := req.Method

	if method == http.MethodGet {
		return s.allocPauseGet(allocID, resp, req)
	}
	if method == http.MethodPost || method == http.MethodPut {
		return s.allocPauseSet(allocID, resp, req)
	}
	return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod)
}

// allocPauseGet issues the GetPauseState RPC for the task named by the "task"
// query parameter within the given allocation. Errors for a missing node
// connection or an unknown allocation are reported as 404s.
func (s *HTTPServer) allocPauseGet(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
	// Build the request and parse the ACL token
	stateReq := structs.AllocGetPauseStateRequest{
		AllocID: allocID,
		Task:    req.URL.Query().Get("task"),
	}
	s.parse(resp, req, &stateReq.QueryOptions.Region, &stateReq.QueryOptions)

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		reply  structs.AllocGetPauseStateResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.GetPauseState", &stateReq, &reply)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.GetPauseState", &stateReq, &reply)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.GetPauseState", &stateReq, &reply)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return reply, rpcErr
}

// allocPauseSet issues the SetPauseState RPC for a task. The JSON body
// supplies Task and ScheduleState, where ScheduleState must be one of
// "pause", "run" or "scheduled"; anything else is rejected with a 400. Errors
// for a missing node connection or an unknown allocation are reported as
// 404s.
func (s *HTTPServer) allocPauseSet(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
	// Build the request and parse the ACL token
	pauseReq := structs.AllocPauseRequest{AllocID: allocID}
	s.parse(resp, req, &pauseReq.QueryOptions.Region, &pauseReq.QueryOptions)

	// The body is decoded into a narrow struct, not the RPC request itself,
	// so that it cannot override the allocation ID taken from the URL.
	var body struct {
		Task          string
		ScheduleState string
	}
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil && err != io.EOF {
		return nil, err
	}
	pauseReq.Task = body.Task

	// Translate the user-facing state name into the RPC constant.
	switch state := body.ScheduleState; state {
	case "pause":
		pauseReq.ScheduleState = structs.TaskScheduleStateForcePause
	case "run":
		pauseReq.ScheduleState = structs.TaskScheduleStateForceRun
	case "scheduled":
		pauseReq.ScheduleState = structs.TaskScheduleStateSchedResume
	default:
		return nil, CodedError(400, "Not a valid task schedule state")
	}

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		reply  structs.GenericResponse
		rpcErr error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.SetPauseState", &pauseReq, &reply)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.SetPauseState", &pauseReq, &reply)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.SetPauseState", &pauseReq, &reply)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return reply, rpcErr
}

// allocSnapshot writes a snapshot of the allocation's filesystem to the
// response. The request token must pass the local client's migrate-token
// validation for the allocation; otherwise permission is denied. The caller
// is responsible for ensuring a local client is running.
func (s *HTTPServer) allocSnapshot(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	var token string
	s.parseToken(req, &token)

	client := s.agent.Client()
	if valid := client.ValidateMigrateToken(allocID, token); !valid {
		return nil, structs.ErrPermissionDenied
	}

	// Any lookup failure is reported to the caller as "not found".
	fs, err := client.GetAllocFS(allocID)
	if err != nil {
		return nil, errors.New(allocNotFoundErr)
	}

	err = fs.Snapshot(resp)
	if err != nil {
		return nil, fmt.Errorf("error making snapshot: %v", err)
	}
	return nil, nil
}

// allocStats issues the Stats RPC for an allocation, optionally narrowed to
// the task named by the "task" query parameter, and returns the Stats field
// of the reply. Errors for a missing node connection or an unknown allocation
// are reported as 404s.
func (s *HTTPServer) allocStats(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Build the request and parse the ACL token
	statsReq := cstructs.AllocStatsRequest{
		AllocID: allocID,
		Task:    req.URL.Query().Get("task"),
	}
	s.parse(resp, req, &statsReq.QueryOptions.Region, &statsReq.QueryOptions)

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		statsResp cstructs.AllocStatsResponse
		rpcErr    error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.Stats", &statsReq, &statsResp)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.Stats", &statsReq, &statsResp)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.Stats", &statsReq, &statsResp)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return statsResp.Stats, rpcErr
}

// allocChecks issues the Checks RPC for an allocation and returns the Results
// field of the reply. Errors for a missing node connection or an unknown
// allocation are reported as 404s.
func (s *HTTPServer) allocChecks(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
	// Build the request and parse the ACL token
	checksReq := cstructs.AllocChecksRequest{AllocID: allocID}
	s.parse(resp, req, &checksReq.QueryOptions.Region, &checksReq.QueryOptions)

	// Determine the handler to use
	local, viaClient, viaServer := s.rpcHandlerForAlloc(allocID)

	// Make the RPC
	var (
		checksResp cstructs.AllocChecksResponse
		rpcErr     error
	)
	switch {
	case local:
		rpcErr = s.agent.Client().ClientRPC("Allocations.Checks", &checksReq, &checksResp)
	case viaClient:
		rpcErr = s.agent.Client().RPC("ClientAllocations.Checks", &checksReq, &checksResp)
	case viaServer:
		rpcErr = s.agent.Server().RPC("ClientAllocations.Checks", &checksReq, &checksResp)
	default:
		rpcErr = CodedError(400, "No local Node and node_id not provided")
	}

	if rpcErr != nil && (structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr)) {
		rpcErr = CodedError(404, rpcErr.Error())
	}
	return checksResp.Results, rpcErr
}

// allocExec upgrades the request to a websocket and streams an exec session
// for a task of the allocation.
//
// Query parameters:
//   - task:    name of the task to exec into
//   - command: JSON-encoded array of strings holding the command line
//   - tty:     optional boolean; defaults to false
//
// After the upgrade, readWsHandshake may read a handshake message carrying
// the auth token. If the handshake fails, a close frame is sent and the
// websocket is closed before the error is returned.
func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	// Build the request and parse the ACL token
	query := req.URL.Query()
	task := query.Get("task")

	// The command arrives as caller-supplied JSON in the query string, so a
	// decoding failure here means the request was malformed.
	var command []string
	err := json.Unmarshal([]byte(query.Get("command")), &command)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal command into json: %v", err)
	}

	ttyB := false
	if tty := query.Get("tty"); tty != "" {
		ttyB, err = strconv.ParseBool(tty)
		if err != nil {
			return nil, fmt.Errorf("tty value is not a boolean: %v", err)
		}
	}

	args := cstructs.AllocExecRequest{
		AllocID: allocID,
		Task:    task,
		Cmd:     command,
		Tty:     ttyB,
	}
	s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)

	conn, err := s.wsUpgrader.Upgrade(resp, req, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to upgrade connection: %v", err)
	}

	if err := readWsHandshake(conn.ReadJSON, req, &args.QueryOptions); err != nil {
		// Best effort: tell the peer why we are going away. Nothing else
		// owns the upgraded connection on this path, so it must be closed
		// here or it is leaked.
		_ = conn.WriteMessage(websocket.CloseMessage,
			websocket.FormatCloseMessage(toWsCode(400), err.Error()))
		_ = conn.Close()
		return nil, err
	}

	return s.execStream(conn, &args)
}

// readWsHandshake performs the optional websocket handshake. A handshake is
// only read when the "ws_handshake" query parameter parses as true; when the
// parameter is absent or false the function returns nil without reading. On
// a successful handshake with the supported version, the auth token from the
// message is stored in q.
func readWsHandshake(readFn func(interface{}) error, req *http.Request, q *structs.QueryOptions) error {
	// The only handshake message version this code understands.
	const supportedWSHandshakeVersion = 1

	// Avoid handshake if request doesn't require one
	raw := req.URL.Query().Get("ws_handshake")
	if raw == "" {
		return nil
	}
	wanted, err := strconv.ParseBool(raw)
	if err != nil {
		return fmt.Errorf("ws_handshake value is not a boolean: %v", err)
	}
	if !wanted {
		return nil
	}

	var msg wsHandshakeMessage
	if err := readFn(&msg); err != nil {
		return err
	}
	if msg.Version != supportedWSHandshakeVersion {
		return fmt.Errorf("unexpected handshake value: %v", msg.Version)
	}

	q.AuthToken = msg.AuthToken
	return nil
}

// wsHandshakeMessage is the JSON message read by readWsHandshake when a
// websocket handshake is requested.
type wsHandshakeMessage struct {
	// Version is the handshake version; readWsHandshake accepts only 1.
	Version   int    `json:"version"`
	// AuthToken is copied into the request's query options on success.
	AuthToken string `json:"auth_token"`
}

// execStream finds the appropriate RPC handler and then runs the bidirectional
// websocket-to-RPC stream.
//
// If no handler can be selected, or looking the handler up fails, a close
// frame is sent to the peer (best effort), the websocket is closed and the
// coded error is returned. Without this the upgraded connection would be
// left open, and a missing handler would reach execStreamImpl as a nil
// function.
func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) {
	const method = "Allocations.Exec"

	// abort notifies the peer, releases the websocket and returns the error.
	abort := func(codedErr HTTPCodedError) (any, error) {
		_ = ws.WriteMessage(websocket.CloseMessage,
			websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error()))
		_ = ws.Close()
		return nil, codedErr
	}

	// Get the correct handler
	localClient, remoteClient, localServer := s.rpcHandlerForAlloc(args.AllocID)

	var handler structs.StreamingRpcHandler
	var handlerErr error
	switch {
	case localClient:
		handler, handlerErr = s.agent.Client().StreamingRpcHandler(method)
	case remoteClient:
		handler, handlerErr = s.agent.Client().RemoteStreamingRpcHandler(method)
	case localServer:
		handler, handlerErr = s.agent.Server().StreamingRpcHandler(method)
	default:
		// Same response the non-streaming handlers give when no route to
		// the allocation exists.
		return abort(CodedError(400, "No local Node and node_id not provided"))
	}

	if handlerErr != nil {
		return abort(CodedError(500, handlerErr.Error()))
	}

	return s.execStreamImpl(ws, args, handler)
}

// execStreamImpl is called by execStream with the appropriate RPC handler and
// then runs the bidirectional websocket-to-RPC stream.
//
// It connects the handler to the websocket through an in-memory pipe: the
// exec request and any input read from the websocket are msgpack-encoded
// onto the pipe, and messages decoded from the pipe are written to the
// websocket as text messages. The call blocks until the handler returns and
// the response goroutine has delivered its final result; the websocket is
// closed before returning. The returned value is always nil; the error is the
// final stream error, or nil if the stream ended with a benign close error.
func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) {

	// Create a pipe connecting the (possibly remote) handler to the http response
	httpPipe, handlerPipe := net.Pipe()
	decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
	encoder := codec.NewEncoder(httpPipe, structs.MsgpackHandle)

	// Create a goroutine that closes the pipe if the connection closes.
	// Closing our end of the pipe is what unblocks a pending Decode in the
	// response loop below once the context is cancelled.
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		<-ctx.Done()
		httpPipe.Close()

		// don't close ws - wait to drain messages
	}()

	// Create a channel for the final result. It is buffered so the response
	// goroutine can deliver its result and exit without waiting for the
	// receive at the bottom of this function.
	resultCh := make(chan HTTPCodedError, 1)

	// stream response back to the websocket: this should be the only goroutine
	// that writes to this websocket connection
	go func() {
		defer cancel()
		// errCh is shared with forwardExecInput, which also reports its
		// failures here.
		errCh := make(chan HTTPCodedError, 2)

		// Send the request
		if err := encoder.Encode(args); err != nil {
			resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error()))
			return
		}

		// only start this after we've tried to send the initial args
		go forwardExecInput(ctx, encoder, ws, errCh)

		for {
			// Non-blocking check for a queued error; the first one found
			// ends the stream and becomes the final result.
			select {
			case codedErr := <-errCh:
				resultCh <- s.execStreamHandleError(ws, codedErr)
				return
			default:
			}

			// Failures below are queued on errCh rather than handled in
			// place, so they are picked up by the check at the top of the
			// next iteration.
			var res cstructs.StreamErrWrapper
			err := decoder.Decode(&res)
			if err != nil {
				errCh <- CodedError(500, err.Error())
				continue
			}
			decoder.Reset(httpPipe)

			if err := res.Error; err != nil {
				// Use the code carried by the stream error when present,
				// otherwise fall back to 500.
				code := 500
				if err.Code != nil {
					code = int(*err.Code)
				}
				errCh <- CodedError(code, err.Error())
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil {
				errCh <- CodedError(500, err.Error())
				continue
			}
		}
	}()

	// start streaming request to streaming RPC - returns when streaming
	// completes or errors
	handler(handlerPipe)

	// stop streaming background goroutines for streaming - but not websocket
	// activity
	cancel()

	// wait for the response goroutine to deliver its final result (which may
	// be nil) before closing the websocket connection
	result := <-resultCh
	ws.Close()
	return nil, result
}

// execStreamHandleError decides what to do with the error that ended an exec
// stream. The error is always trace-logged. Nil errors and benign "close"
// errors (see isClosedError) are discarded and nil is returned. Any other
// error is reported to the peer in a websocket CloseMessage and returned
// unchanged. Because it may write to the websocket, this must only be called
// from the goroutine that streams responses.
func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError {
	// we won't return an error on ws close, but at least make it available in
	// the logs so we can trace spurious disconnects
	s.logger.Trace("alloc exec channel closed with error", "error", codedErr)

	// Nothing to report, or an expected shutdown error that we
	// intentionally throw away.
	if codedErr == nil || isClosedError(codedErr) {
		return nil
	}

	closeMsg := websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error())
	ws.WriteMessage(websocket.CloseMessage, closeMsg)
	return codedErr
}

// toWsCode maps an HTTP status code to a websocket close code. A 500 maps to
// the internal-server-error close code; every other status maps to the
// policy-violation close code.
func toWsCode(httpCode int) int {
	if httpCode == 500 {
		return websocket.CloseInternalServerErr
	}
	// placeholder error code
	return websocket.ClosePolicyViolation
}

// isClosedError reports whether err is one of the benign errors seen when an
// exec stream shuts down: io.EOF or io.ErrClosedPipe (possibly wrapped), or
// an error whose message contains one of the known "closed" markers. A nil
// error is not a closed error.
func isClosedError(err error) bool {
	if err == nil {
		return false
	}

	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
		return true
	}

	// Some errors arrive only as text (e.g. after crossing the msgpack RPC
	// boundary), so fall back to matching on the message.
	benignMarkers := []string{
		"closed", // msgpack decode error [pos 0]: io: read/write on closed pipe"
		"EOF",
		"close 1000", // CLOSE_NORMAL
		"close 1001", // CLOSE_GOING_AWAY
		"close 1005", // CLOSED_NO_STATUS
	}
	msg := err.Error()
	return slices.ContainsFunc(benignMarkers, func(marker string) bool {
		return strings.Contains(msg, marker)
	})
}

// forwardExecInput forwards exec input (e.g. stdin) from websocket connection
// to the streaming RPC connection to client.
//
// It loops until ctx is cancelled (checked before each read), the websocket
// read returns io.EOF, or an error occurs. A read or encode failure is
// reported once on errCh as a 500 coded error and the function then returns,
// so this goroutine never queues more than one error.
func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		sf := &drivers.ExecTaskStreamingRequestMsg{}
		if err := ws.ReadJSON(sf); err != nil {
			if err != io.EOF {
				errCh <- CodedError(500, err.Error())
			}
			return
		}

		if err := encoder.Encode(sf); err != nil {
			errCh <- CodedError(500, err.Error())
			// Stop after the first failure. Continuing to read input would
			// queue further errors for every later message, and a blocked
			// send on the small errCh buffer would leave this goroutine (or
			// the response loop that shares the channel) stuck forever.
			return
		}
	}
}
