diff --git a/cmd/cmd_root.go b/cmd/cmd_root.go index 5274ae8..fbeb552 100644 --- a/cmd/cmd_root.go +++ b/cmd/cmd_root.go @@ -30,12 +30,14 @@ var ( flagSessionToken string flagEnvFile string flagCreateDebugSession bool + flagLocal bool finalConfigFile string finalConcurrency string finalSessionToken string finalConfigValueSource string finalCreateDebugSession bool + finalLocal bool finalGraphFile string finalGraphArgs []string @@ -109,6 +111,8 @@ var cmdRoot = &cobra.Command{ }) finalCreateDebugSession = finalCreateDebugSessionStr == "true" || finalCreateDebugSessionStr == "1" + finalLocal = flagLocal + // the block below is used to distinguish between implicit graph files (eg if defined in an env var) + graph flags // vs explicit graph file (eg provided by positional arg) + graph flags. @@ -147,6 +151,13 @@ var cmdRoot = &cobra.Command{ return errors.New("when using --create-debug-session, a graph file must be specified") } + if finalLocal && finalSessionToken != "" { + return errors.New("--local and --session-token cannot be used together") + } + if finalLocal && finalCreateDebugSession { + return errors.New("--local and --create-debug-session cannot be used together") + } + return nil }, } @@ -155,6 +166,16 @@ func cmdRootRun(cmd *cobra.Command, args []string) { utils.SetConcurrencyEnabled(finalConcurrency == "" || finalConcurrency == "true" || finalConcurrency == "1") + // start a local WS server for local connections (eg the vscode extension) + if finalLocal { + err := sessions.RunLocalMode(finalConfigFile) + if err != nil { + utils.LogErr.Print(err.Error()) + os.Exit(1) + } + return + } + // if we still have no graph file, go to Session Mode if finalGraphFile == "" || finalCreateDebugSession { trapfn := func() { @@ -231,6 +252,7 @@ func init() { cmdRoot.Flags().StringVar(&flagConcurrency, "concurrency", "", "Enable or disable concurrency") cmdRoot.Flags().StringVar(&flagSessionToken, "session-token", "", "The session token from your browser") cmdRoot.Flags().BoolVar(&flagCreateDebugSession, "create-debug-session", false, "Create a debug session by connecting to the web app") + cmdRoot.Flags().BoolVar(&flagLocal, "local", false, "Start a local WebSocket server for direct editor connection") // disable interspersed flag parsing to allow passing arbitrary flags to graphs. // it stops cobra from parsing flags once it hits positional argument diff --git a/nodes/dir-walk@v1.go b/nodes/dir-walk@v1.go index bcd775b..1c5d2cf 100644 --- a/nodes/dir-walk@v1.go +++ b/nodes/dir-walk@v1.go @@ -154,7 +154,7 @@ func walk(root string, opts walkOpts, pattern []string, items map[string]os.File }) } else { - entries, err := os.ReadDir(root) + entries, err := os.ReadDir(filepath.Clean(root)) if err != nil { return "", core.CreateErr(nil, err, "failed to read directory") } diff --git a/nodes/run@v1.go b/nodes/run@v1.go index b921d5c..3069685 100644 --- a/nodes/run@v1.go +++ b/nodes/run@v1.go @@ -249,6 +249,8 @@ func runCommand(c *core.ExecutionState, shell string, script *string, args []str curEnvMap["PYTHONIOENCODING"] = "utf-8" args = append([]string{scriptPath}, args...) + default: + return "", 0, core.CreateErr(c, nil, "unsupported shell: %s", shell) } cmd = exec.Command(shell, args...) diff --git a/sessions/gateway.go b/sessions/gateway.go new file mode 100644 index 0000000..8512ec3 --- /dev/null +++ b/sessions/gateway.go @@ -0,0 +1,374 @@ +package sessions + +import ( + "bufio" + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "os/signal" + "runtime/debug" + "strings" + "sync" + "syscall" + + "github.com/actionforge/actrun-cli/build" + "github.com/actionforge/actrun-cli/utils" + "github.com/gorilla/websocket" +) + +func RunSessionMode(configFile string, graphFileForDebugSession string, sessionToken string, configValueSource string) error { + + if graphFileForDebugSession != "" && sessionToken != "" { + return errors.New("both createDebugSession and sessionToken cannot be set") + } + + if graphFileForDebugSession == "" { + PrintWelcomeMessage() + } + + if configFile != "" { + utils.LogOut.Infof("šŸ‘‰ Configs will be loaded from: %s\n", configFile) + _, err := utils.LoadConfig(configFile) + if err != nil { + return fmt.Errorf("error loading config: %v", err) // fmt.Errorf doesn't strictly need \n if returned as error + } + } else { + utils.LogOut.Info("No config file specified, config values will be derived from environment variables and flags\n") + } + + apiGatewayUrl := GetGatewayURL() + + wsScheme := "wss" + httpScheme := "https" + if apiGatewayUrl == "localhost" || strings.HasPrefix(apiGatewayUrl, "localhost:") { + wsScheme = "ws" + httpScheme = "http" + } + + var err error + if graphFileForDebugSession != "" { + sessionData, err := StartNewSession(httpScheme, apiGatewayUrl) + if err != nil { + return fmt.Errorf("error creating new debug session: %v", err) + } + sessionToken = sessionData.Token + + utils.LogOut.Infof("šŸ‘‰ Created new debug session for graph file: %s\n", graphFileForDebugSession) + utils.LogOut.Infof("Debug Session: %s\n", fmt.Sprintf("%s//%s/graph#%s", httpScheme, APP_URL, "")) + } else { + sessionToken, err = GetSessionToken(sessionToken, configValueSource) + if err != nil { + return fmt.Errorf("error reading session token: %v", err) + } + } + + if sessionToken == "" { + return fmt.Errorf("no session token provided, exiting.") + } + + // token validation and parsing + packet, err := base64.StdEncoding.DecodeString(sessionToken) + if err != nil { + return fmt.Errorf("invalid token string (not Base64): %v", err) + } + + if len(packet) < 38 { + return errors.New("invalid token (too short).") + } + + expectedChecksum := packet[len(packet)-4:] + dataPayload := packet[:len(packet)-4] + + idLength := int(packet[0]) + if idLength <= 0 || (1+idLength+32) > len(dataPayload) { + return fmt.Errorf("invalid token (malformed structure).") + } + + sessionIDBytes := packet[1 : 1+idLength] + keyBytes := packet[1+idLength : 1+idLength+32] + + dataToHash := append([]byte{}, sessionIDBytes...) + dataToHash = append(dataToHash, keyBytes...) + + hash := sha256.Sum256(dataToHash) + calculatedChecksum := hash[:4] + + if !bytes.Equal(expectedChecksum, calculatedChecksum) { + return fmt.Errorf("āŒ INTEGRITY CHECK FAILED: The token appears to be modified or typo'd.\nCheck the last few characters") + } + + sessionID := string(sessionIDBytes) + sharedKey := base64.StdEncoding.EncodeToString(keyBytes) + send := newEncryptedSender(sharedKey) + + uAddr := url.URL{Scheme: wsScheme, Host: apiGatewayUrl, Path: "/api/v2/ws/runner/" + sessionID} + utils.LogOut.Info("Connecting to Actionforge\n") + + ws, resp, err := websocket.DefaultDialer.Dial(uAddr.String(), nil) + if err != nil { + if resp != nil { + body, readErr := io.ReadAll(resp.Body) + if readErr == nil { + var errMsg map[string]string + if json.Unmarshal(body, &errMsg) == nil && errMsg["error"] != "" { + return fmt.Errorf("🚨 Error: %s", errMsg["error"]) + } + return fmt.Errorf("handshake failed (Status %s): %s", resp.Status, string(body)) + } + return fmt.Errorf("handshake failed: Server returned HTTP status: %s", resp.Status) + } + return fmt.Errorf("failed to connect to %v: %v", apiGatewayUrl, err) + } + defer ws.Close() + + utils.LogOut.Info("Successfully connected to your browser session. Waiting for commands...\n") + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + + // if browser disconnects during a --create-debug-session run, we switch to detached mode + // to ensure the graph finishes execution instead of hanging on a breakpoint. + var detachMu sync.Mutex + var detachedMode bool + + var ops debugOps + + // if browser disconnected override pause to ensure the graph finishes. + // Its the same behaviour if you detach a debugger in an IDE. + shouldSkipPause := func() bool { + detachMu.Lock() + defer detachMu.Unlock() + return detachedMode + } + + var onGraphComplete func() + if graphFileForDebugSession != "" { + onGraphComplete = func() { + done <- syscall.SIGTERM + } + } + + // cli auto start logic + if graphFileForDebugSession != "" { + graphContent, err := os.ReadFile(graphFileForDebugSession) + if err != nil { + return fmt.Errorf("failed to read debug graph file: %v", err) + } + + go func() { + graphContentBase64 := base64.URLEncoding.EncodeToString(graphContent) + + fragmentParams := url.Values{} + fragmentParams.Set("graph", graphContentBase64) + fragmentParams.Set("session_token", sessionToken) + + fragmentString := fragmentParams.Encode() + + utils.LogOut.Infof("šŸ‘‰ Debug Session: %s\n", fmt.Sprintf("%s://%s/graph#%s", httpScheme, APP_URL, fragmentString)) + + // Force StartPaused = true + triggerGraphExecution(&ops, ws, send, configFile, string(graphContent), nil, nil, nil, nil, true, false, shouldSkipPause, onGraphComplete) + }() + } + + // this is the main message loop + go func() { + defer func() { + if r := recover(); r != nil { + utils.LogOut.Errorf("recovered from panic in message loop: %v\n%s", r, debug.Stack()) + } + done <- syscall.SIGTERM + }() + + for { + var rawMsg EncryptedMessage + err := ws.ReadJSON(&rawMsg) + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + utils.LogOut.Debug("server closed connection cleanly.\n") + } else if strings.Contains(err.Error(), "use of closed network connection") { + // TODO: (Seb) check if there is a better way to handle this + // We reach this when the session shuts down and closes the socket + // while this loop is still waiting for a read. We just ignore it as + // its not really a bug + } else { + utils.LogOut.Warnf("WebSocket Error: %v\n", err) + } + break + } + + if rawMsg.Type == MsgTypeControl { + utils.LogOut.Debugf("received control message: %s\n", rawMsg.Payload) + + switch rawMsg.Payload { + case ControlBrowserDisconnected: + utils.LogOut.Debug("browser disconnected (waiting for reconnect...)\n") + + // if browser disconnected override pause to ensure the graph finishes + // its the same behaviour if you detach a debugger in an IDE + if graphFileForDebugSession != "" { + utils.LogOut.Debug("debug session detected: Resuming graph to completion...\n") + detachMu.Lock() + detachedMode = true + detachMu.Unlock() + + ops.dispatch(MsgTypeDebugResume, "") + } + + case ControlBrowserConnected: + utils.LogOut.Debug("browser connected. Checking for active debug state...\n") + ops.Lock() + if ops.cachedState != nil { + utils.LogOut.Debug("resending execution state to new browser connection...\n") + go send(ws, ops.cachedState) + } + ops.Unlock() + } + + continue + } + + if rawMsg.Type != MsgTypeData { + utils.LogOut.Warnf("Received non-data message type, ignoring: %v\n", rawMsg.Type) + continue + } + + decryptedJSON, err := decryptData(rawMsg.Payload, sharedKey) + if err != nil { + utils.LogOut.Errorf("dECRYPTION FAILED: %v", err) + send(ws, map[string]string{ + "type": MsgTypeJobError, + "error": "Decryption failed. Check your key.", + }) + continue + } + + var payload DecryptedPayload + if err := json.Unmarshal([]byte(decryptedJSON), &payload); err != nil { + utils.LogOut.Warnf("Failed to parse decrypted JSON: %v\n", err) + continue + } + + currentVer := build.Version + if isVersionOutdated(currentVer, payload.RequiredVersion) { + utils.LogOut.Warnf("Runner version %s is older than required %s\n", currentVer, payload.RequiredVersion) + send(ws, map[string]string{ + "type": MsgTypeWarning, + "message": fmt.Sprintf("WARNING: Runner version %s is older than required %s", currentVer, payload.RequiredVersion), + }) + } + + switch payload.Type { + + case MsgTypeRun: + triggerGraphExecution( + &ops, ws, send, configFile, + payload.Payload, + payload.Secrets, + payload.Inputs, + payload.Env, + payload.Breakpoints, + payload.StartPaused, + payload.IgnoreBreakpoints, + shouldSkipPause, + onGraphComplete, + ) + + case MsgTypeStop: + utils.LogOut.Debug("received stop signal\n") + send(ws, map[string]string{ + "type": MsgTypeLog, + "message": "Stop signal received. Attempting to cancel...", + }) + ops.cancelAndResume() + + case MsgTypeDebugStep, MsgTypeDebugStepInto, MsgTypeDebugStepOut, + MsgTypeDebugPause, MsgTypeDebugResume, + MsgTypeDebugAddBreakpoint, MsgTypeDebugRemoveBreakpoint: + ops.dispatch(payload.Type, payload.NodeID) + + default: + utils.LogOut.Debugf("unknown command type: %s\n", payload.Type) + } + } + }() + + <-done + utils.LogOut.Debug("shutting down runtime...\n") + + wsWriteMutex.Lock() + _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + wsWriteMutex.Unlock() + + return nil +} + +// GetSessionToken waits for the user to paste a token into standard input, +// reads it, trims it, and returns it. +// It returns the token (string) and any error encountered during reading. +func GetSessionToken(sessionToken string, configValueSource string) (string, error) { + fmt.Println() + fmt.Print("šŸ”‘ Enter session token: ") + + if sessionToken != "" { + fmt.Printf("\n\n", configValueSource) + return sessionToken, nil + } + + for { + + scanner := bufio.NewScanner(os.Stdin) + + if scanner.Scan() { + token := strings.TrimSpace(scanner.Text()) + + if token == "" || strings.EqualFold(token, "exit") || strings.EqualFold(token, "quit") { + return "", nil + } + + if len(token) < 16 { + fmt.Print(" Warning: That doesn't look like a valid session token. Please try again or type 'exit' to quit.\n") + fmt.Print("šŸ”‘ Enter session token: ") + continue + } + + return token, nil + } + + if err := scanner.Err(); err != nil { + return "", err + } + + return "", nil + } +} + +func PrintWelcomeMessage() { + welcomeText := `Welcome to your Actionforge Runner + +----------------------[ HOW TO RUN ]---------------------- + +[ šŸš€ OPTION 1: RUN LOCAL ACTION GRAPH ] + Execute a local graph file directly from your terminal. + Example: $ actrun my-graph.act + +[ šŸ”— OPTION 2: CONNECT TO WEB APP ] + Please paste the session token from your browser to connect. + +---------------------------------------------------------- + +šŸ“– Docs: https://docs.actionforge.dev + +` + + // Print the message to standard output. + // We use fmt.Print here instead of Println to avoid adding an extra + // newline at the very end, keeping the cursor right after the prompt. + fmt.Print(welcomeText) +} diff --git a/sessions/local.go b/sessions/local.go new file mode 100644 index 0000000..f0be978 --- /dev/null +++ b/sessions/local.go @@ -0,0 +1,182 @@ +package sessions + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "runtime/debug" + "strings" + "sync" + "syscall" + + "github.com/actionforge/actrun-cli/build" + "github.com/actionforge/actrun-cli/utils" + "github.com/gorilla/websocket" +) + +var localUpgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accept all origins for local connections + }, +} + +// RunLocalMode starts a local WebSocket server for direct editor connection (no gateway). +func RunLocalMode(configFile string) error { + + if configFile != "" { + utils.LogOut.Infof("šŸ‘‰ Configs will be loaded from: %s\n", configFile) + _, err := utils.LoadConfig(configFile) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + } else { + utils.LogOut.Info("No config file specified, config values will be derived from environment variables and flags\n") + } + + send := newPlainSender() + + // listen on a random available port on localhost + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return fmt.Errorf("failed to start local server: %v", err) + } + port := listener.Addr().(*net.TCPAddr).Port + + // print the port for the VS Code extension to capture + fmt.Printf("LOCAL_WS_PORT=%d\n", port) + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + + var wsConn *websocket.Conn + var wsConnMu sync.Mutex + + var ops debugOps + + mux := http.NewServeMux() + mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + ws, err := localUpgrader.Upgrade(w, r, nil) + if err != nil { + utils.LogOut.Errorf("failed to upgrade local WebSocket connection: %v\n", err) + return + } + + wsConnMu.Lock() + if wsConn != nil { + wsConnMu.Unlock() + _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage( + websocket.ClosePolicyViolation, + "Another client is already connected.", + )) + ws.Close() + return + } + wsConn = ws + wsConnMu.Unlock() + + utils.LogOut.Info("Editor connected via local WebSocket.\n") + + send(ws, map[string]string{ + "type": MsgTypeControl, + "message": "runner_connected", + "address": "127.0.0.1", + }) + + defer func() { + if r := recover(); r != nil { + utils.LogOut.Errorf("recovered from panic in local message loop: %v\n%s", r, debug.Stack()) + } + wsConnMu.Lock() + wsConn = nil + wsConnMu.Unlock() + ws.Close() + done <- syscall.SIGTERM + }() + + for { + _, msgBytes, err := ws.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + utils.LogOut.Debug("editor closed connection cleanly.\n") + } else if !strings.Contains(err.Error(), "use of closed network connection") { + utils.LogOut.Warnf("local WebSocket error: %v\n", err) + } + break + } + + var payload DecryptedPayload + if err := json.Unmarshal(msgBytes, &payload); err != nil { + utils.LogOut.Warnf("failed to parse JSON from editor: %v\n", err) + continue + } + + currentVer := build.Version + if isVersionOutdated(currentVer, payload.RequiredVersion) { + utils.LogOut.Warnf("Runner version %s is older than required %s\n", currentVer, payload.RequiredVersion) + send(ws, map[string]string{ + "type": MsgTypeWarning, + "message": fmt.Sprintf("WARNING: Runner version %s is older than required %s", currentVer, payload.RequiredVersion), + }) + } + + switch payload.Type { + + case MsgTypeRun: + triggerGraphExecution( + &ops, ws, send, configFile, + payload.Payload, + payload.Secrets, + payload.Inputs, + payload.Env, + payload.Breakpoints, + payload.StartPaused, + payload.IgnoreBreakpoints, + nil, nil, + ) + + case MsgTypeStop: + utils.LogOut.Debug("received stop signal\n") + send(ws, map[string]string{ + "type": MsgTypeLog, + "message": "Stop signal received. Attempting to cancel...", + }) + ops.cancelAndResume() + + case MsgTypeDebugStep, MsgTypeDebugStepInto, MsgTypeDebugStepOut, + MsgTypeDebugPause, MsgTypeDebugResume, + MsgTypeDebugAddBreakpoint, MsgTypeDebugRemoveBreakpoint: + ops.dispatch(payload.Type, payload.NodeID) + + default: + utils.LogOut.Debugf("unknown command type: %s\n", payload.Type) + } + } + }) + + server := &http.Server{Handler: mux} + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + utils.LogOut.Errorf("local HTTP server error: %v\n", err) + done <- syscall.SIGTERM + } + }() + + <-done + utils.LogOut.Debug("shutting down local runner...\n") + + wsConnMu.Lock() + if wsConn != nil { + wsWriteMutex.Lock() + _ = wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + wsWriteMutex.Unlock() + wsConn.Close() + } + wsConnMu.Unlock() + + server.Close() + + return nil +} diff --git a/sessions/protocol.go b/sessions/protocol.go new file mode 100644 index 0000000..c01c6cf --- /dev/null +++ b/sessions/protocol.go @@ -0,0 +1,689 @@ +package sessions + +import ( + "bufio" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/actionforge/actrun-cli/core" + "github.com/actionforge/actrun-cli/utils" + "github.com/gorilla/websocket" +) + +var wsWriteMutex sync.Mutex + +const ( + // Message Types (from browser) + MsgTypeRun = "run" + MsgTypeStop = "stop" + MsgTypeDebugPause = "debug_pause" + MsgTypeDebugResume = "debug_resume" + MsgTypeDebugStep = "debug_step" + MsgTypeDebugAddBreakpoint = "debug_add_breakpoint" + MsgTypeDebugRemoveBreakpoint = "debug_remove_breakpoint" + MsgTypeDebugStepInto = "debug_step_into" + MsgTypeDebugStepOut = "debug_step_out" + + // Message Types (to browser) + MsgTypeLog = "log" + MsgTypeLogError = "log_error" + MsgTypeJobFinished = "job_finished" + MsgTypeJobError = "job_error" + MsgTypeDebugState = "debug_state" + MsgTypeWarning = "warning" + + // Wrapper/Control Message Types (not E2E encrypted) + MsgTypeData = "data" // Wrapper for E2E encrypted payloads + MsgTypeControl = "control" // Server-to-runner control messages + + // Control Message Payloads + ControlBrowserDisconnected = "browser_disconnected" + ControlBrowserConnected = "browser_connected" +) + +func encryptData(plaintext string, base64Key string) (string, error) { + key, err := base64.StdEncoding.DecodeString(base64Key) + if err != nil { + return "", errors.New("failed to decode base64 key") + } + if len(key) != 32 { + return "", errors.New("invalid key length: must be 32 bytes (AES-256)") + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, aesgcm.NonceSize()) // NonceSize() is 12 bytes for AES-GCM + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + // Encrypt the data (nil prefix means append to nonce) + ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) + + ivAndCiphertext := append(nonce, ciphertext...) + + return base64.StdEncoding.EncodeToString(ivAndCiphertext), nil +} + +// MessageSender is a function that sends a payload over a WebSocket connection. +// Both encrypted (gateway) and plain (local) modes implement this signature. +type MessageSender func(ws *websocket.Conn, payload any) + +func newEncryptedSender(sharedKey string) MessageSender { + return func(ws *websocket.Conn, payload any) { + sendEncryptedJSON(ws, payload, sharedKey) + } +} + +func newPlainSender() MessageSender { + return func(ws *websocket.Conn, payload any) { + sendPlainJSON(ws, payload) + } +} + +func sendPlainJSON(ws *websocket.Conn, payload any) { + wsWriteMutex.Lock() + defer wsWriteMutex.Unlock() + + if err := ws.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { + utils.LogOut.Errorf("failed to set write deadline (connection likely closed): %v\n", err) + return + } + + if err := ws.WriteJSON(payload); err != nil { + utils.LogOut.Errorf("failed to send JSON message: %v\n", err) + } +} + +func sendEncryptedJSON(ws *websocket.Conn, payload any, sharedKey string) { + jsonPayload, err := json.Marshal(payload) + if err != nil { + utils.LogOut.Errorf("failed to marshal outgoing JSON: %v\n", err) + return + } + + encryptedPayload, err := encryptData(string(jsonPayload), sharedKey) + if err != nil { + utils.LogOut.Errorf("failed to encrypt outgoing message: %v\n", err) + return + } + + msg := EncryptedMessage{ + Type: MsgTypeData, + Payload: encryptedPayload, + } + + wsWriteMutex.Lock() + defer wsWriteMutex.Unlock() + + if err := ws.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { + utils.LogOut.Errorf("failed to set write deadline (connection likely closed): %v\n", err) + return + } + + if err := ws.WriteJSON(msg); err != nil { + utils.LogOut.Errorf("failed to send encrypted message: %v\n", err) + } +} + +// EncryptedMessage is the raw message received from the WebSocket +type EncryptedMessage struct { + Type string `json:"type"` + Payload string `json:"payload"` // Base64-encoded (IV + Ciphertext) +} + +// DecryptedPayload is the structure of the data *after* decryption +type DecryptedPayload struct { + Type string `json:"type"` + Payload string `json:"payload"` // The graph JSON (if type is "run") + Secrets map[string]string `json:"secrets"` + Inputs map[string]any `json:"inputs"` + Env map[string]string `json:"env"` + IgnoreBreakpoints bool `json:"ignore_breakpoints"` + StartPaused bool `json:"start_paused"` + Breakpoints []string `json:"breakpoints"` + RequiredVersion string `json:"required_version"` + NodeID string `json:"nodeId"` +} + +// Global State +var ( + // Use a channel to signal that a graph is currently running + graphRunning = make(chan bool, 1) + // Mutex to protect access to the cancel function + cancelLock sync.Mutex + // Holds the cancel function for the *current* running graph + currentGraphCancel context.CancelFunc +) + +type debugOps struct { + sync.Mutex + pause func() + resume func() + step func() + stepInto func() + stepOut func() + addBreakpoint func(string) + removeBreakpoint func(string) + cachedState any +} + +func (d *debugOps) cleanup() { + d.Lock() + d.pause = nil + d.resume = nil + d.step = nil + d.stepInto = nil + d.stepOut = nil + d.addBreakpoint = nil + d.removeBreakpoint = nil + d.cachedState = nil + d.Unlock() +} + +func (d *debugOps) dispatch(msgType string, nodeID string) { + d.Lock() + var fn func() + var fnStr func(string) + switch msgType { + case MsgTypeDebugStep: + fn = d.step + case MsgTypeDebugStepInto: + fn = d.stepInto + case MsgTypeDebugStepOut: + fn = d.stepOut + case MsgTypeDebugPause: + fn = d.pause + case MsgTypeDebugResume: + fn = d.resume + case MsgTypeDebugAddBreakpoint: + fnStr = d.addBreakpoint + case MsgTypeDebugRemoveBreakpoint: + fnStr = d.removeBreakpoint + } + d.Unlock() + + if fn != nil { + fn() + } + if fnStr != nil { + fnStr(nodeID) + } +} + +func (d *debugOps) cancelAndResume() { + cancelLock.Lock() + if currentGraphCancel != nil { + currentGraphCancel() + } + cancelLock.Unlock() + + d.Lock() + resumeFn := d.resume + d.Unlock() + if resumeFn != nil { + resumeFn() + } +} + +func triggerGraphExecution( + ops *debugOps, + ws *websocket.Conn, + send MessageSender, + configFile string, + graphPayload string, + secrets map[string]string, + inputs map[string]any, + env map[string]string, + breakpoints []string, + startPaused bool, + ignoreBreakpoints bool, + shouldSkipPause func() bool, + onGraphComplete func(), +) { + select { + case graphRunning <- true: + ctx, cancel := context.WithCancel(context.Background()) + + cancelLock.Lock() + currentGraphCancel = cancel + cancelLock.Unlock() + + var debugMu sync.Mutex + debugCond := sync.NewCond(&debugMu) + + var bpMutex sync.RWMutex + activeBreakpoints := make(map[string]bool) + + type StepMode int + const ( + StepRun StepMode = iota + StepOver + StepInto + StepOut + ) + + currentStepMode := StepRun + stepReferenceDepth := 0 + + if len(breakpoints) > 0 { + bpMutex.Lock() + for _, bp := range breakpoints { + activeBreakpoints[bp] = true + } + bpMutex.Unlock() + } + + isPaused := startPaused + + // Setup control functions + ops.Lock() + + ops.pause = func() { + debugMu.Lock() + isPaused = true + currentStepMode = StepRun + debugMu.Unlock() + utils.LogOut.Debug("pausing execution...\n") + } + + ops.resume = func() { + debugMu.Lock() + isPaused = false + currentStepMode = StepRun + ops.Lock() + ops.cachedState = nil + ops.Unlock() + debugCond.Broadcast() + debugMu.Unlock() + utils.LogOut.Debug("resuming execution...\n") + } + + ops.step = func() { + debugMu.Lock() + currentStepMode = StepOver + debugMu.Unlock() + ops.Lock() + ops.cachedState = nil + ops.Unlock() + debugMu.Lock() + debugCond.Signal() + debugMu.Unlock() + utils.LogOut.Debug("stepping Over...\n") + } + + ops.stepInto = func() { + debugMu.Lock() + currentStepMode = StepInto + debugMu.Unlock() + ops.Lock() + ops.cachedState = nil + ops.Unlock() + debugMu.Lock() + debugCond.Signal() + debugMu.Unlock() + utils.LogOut.Debug("stepping Into...\n") + } + + ops.stepOut = func() { + debugMu.Lock() + currentStepMode = StepOut + debugMu.Unlock() + ops.Lock() + ops.cachedState = nil + ops.Unlock() + debugMu.Lock() + debugCond.Signal() + debugMu.Unlock() + utils.LogOut.Debug("stepping Out...\n") + } + + ops.addBreakpoint = func(nodeId string) { + bpMutex.Lock() + activeBreakpoints[nodeId] = true + bpMutex.Unlock() + utils.LogOut.Debugf("breakpoint added at %s\n", nodeId) + } + + ops.removeBreakpoint = func(nodeId string) { + bpMutex.Lock() + delete(activeBreakpoints, nodeId) + bpMutex.Unlock() + utils.LogOut.Debugf("breakpoint removed at %s\n", nodeId) + } + ops.Unlock() + + lastKnownDepth := 0 + + debugCb := func(ec *core.ExecutionState, nodeVisit core.ContextVisit) { + fullPath := nodeVisit.Node.GetFullPath() + currentDepth := calculateGraphDepth(fullPath) + utils.LogOut.Debugf("visiting %s | Paused: %v\n", fullPath, isPaused) + + bpMutex.RLock() + hasBreakpoint := activeBreakpoints[fullPath] + bpMutex.RUnlock() + + debugMu.Lock() + + if hasBreakpoint { + utils.LogOut.Debugf("hit explicit breakpoint at %s\n", fullPath) + isPaused = true + currentStepMode = StepRun + } else if !isPaused { + switch currentStepMode { + case StepInto: + isPaused = true + currentStepMode = StepRun + case StepOver: + if currentDepth <= stepReferenceDepth { + isPaused = true + currentStepMode = StepRun + } + case StepOut: + if currentDepth < stepReferenceDepth { + isPaused = true + currentStepMode = StepRun + } + } + } + + if isPaused { + lastKnownDepth = currentDepth + } + + if shouldSkipPause != nil && shouldSkipPause() { + isPaused = false + } + + if isPaused { + utils.LogOut.Infof("debugging paused at node: %s\n", fullPath) + + var rootEc *core.ExecutionState = ec + for rootEc.ParentExecution != nil { + rootEc = rootEc.ParentExecution + } + + debugState := map[string]any{ + "type": MsgTypeDebugState, + "fullPath": fullPath, + "executionContext": *rootEc, + } + + go send(ws, debugState) + + ops.Lock() + ops.cachedState = debugState + ops.Unlock() + + debugCond.Wait() + + stepReferenceDepth = lastKnownDepth + isPaused = false + } + + debugMu.Unlock() + } + + if ignoreBreakpoints { + activeBreakpoints = make(map[string]bool) + debugCb = nil + } + + go func() { + runGraphFromConn(ctx, graphPayload, core.RunOpts{ + ConfigFile: configFile, + OverrideSecrets: secrets, + OverrideInputs: inputs, + OverrideEnv: env, + Args: []string{}, + }, ws, send, debugCb) + + ops.cleanup() + + if onGraphComplete != nil { + onGraphComplete() + } + }() + + default: + utils.LogOut.Warn("Cannot run graph: another graph is already in progress.\n") + send(ws, map[string]string{ + "type": MsgTypeJobError, + "error": "A graph is already running.", + }) + } +} + +func runGraphFromConn(ctx context.Context, graphData string, opts core.RunOpts, ws *websocket.Conn, send MessageSender, debugCb core.DebugCallback) { + + // *must* release the lock when it's done + defer func() { + <-graphRunning + + // cleanup the cancel function so "stop" can't be called on a finished job + cancelLock.Lock() + currentGraphCancel = nil + cancelLock.Unlock() + }() + + origStdout := os.Stdout + origStderr := os.Stderr + origLogOutput := utils.LogOut.Out // <-- this is logruses original output + + rOut, wOut, errOut := os.Pipe() + if errOut != nil { + utils.LogOut.Debugf("failed to create pipe for stdout/log capture: %v\n", errOut) + send(ws, map[string]string{ + "type": MsgTypeJobError, + "error": fmt.Sprintf("Failed to capture stdout/log: %v", errOut), + }) + return + } + + rErr, wErr, errErr := os.Pipe() + if errErr != nil { + wOut.Close() + utils.LogOut.Debugf("failed to create pipe for stderr capture: %v\n", errErr) + send(ws, map[string]string{ + "type": MsgTypeJobError, + "error": fmt.Sprintf("Failed to capture stderr: %v", errErr), + }) + return + } + + os.Stdout = wOut + utils.LogOut.SetOutput(wOut) + + os.Stderr = wErr + + startTime := time.Now() + fmt.Printf("šŸš€ Task started...\n") + + var wg sync.WaitGroup + wg.Add(2) + + // for stdout + go func() { + defer wg.Done() + scanner := bufio.NewScanner(rOut) + for scanner.Scan() { + line := scanner.Text() + + if strings.TrimSpace(line) == "" { + continue + } + + // here we write to original console + fmt.Fprintln(origStdout, line) + + send(ws, map[string]string{ + "type": MsgTypeLog, + "message": fmt.Sprintf("[%s] %s", time.Now().Format("2006-01-02 15:04:05"), line), + }) + } + if err := scanner.Err(); err != nil { + utils.LogOut.Debugf("error reading from stdout/log pipe: %v\n", err) + } + }() + + // for stderr + go func() { + defer wg.Done() + scanner := bufio.NewScanner(rErr) + for scanner.Scan() { + line := scanner.Text() + + if strings.TrimSpace(line) == "" { + continue + } + + // here we write to original console + fmt.Fprintln(origStderr, line) + + send(ws, map[string]string{ + "type": MsgTypeLogError, + "message": line, + }) + } + if err := scanner.Err(); err != nil { + utils.LogOut.Debugf("error reading from stderr pipe: %v\n", err) + } + }() + + runErr := func() (err error) { + defer core.RecoverHandler(false) + return core.RunGraphFromString(ctx, "browser", graphData, core.RunOpts{ + ConfigFile: opts.ConfigFile, + OverrideSecrets: opts.OverrideSecrets, + OverrideInputs: opts.OverrideInputs, + OverrideEnv: opts.OverrideEnv, + Args: []string{}, + }, debugCb) + }() + + endTime := time.Now() + duration := endTime.Sub(startTime) + durationStr := fmt.Sprintf("%.2fs", duration.Seconds()) + + // we print this *before* closing the pipes, so it still gets captured + if runErr != nil { + fmt.Printf("\nāŒ Job failed. (Total time: %s)\n", durationStr) + } else { + fmt.Printf("\nāœ… Job succeeded. (Total time: %s)\n", durationStr) + } + + wOut.Close() + wErr.Close() + + os.Stdout = origStdout + os.Stderr = origStderr + utils.LogOut.SetOutput(origLogOutput) + + wg.Wait() + + // all output has already been streamed, including the summary line. + // now we just send the final status message. + if runErr != nil { + utils.LogOut.Debug("graph execution failed\n") + // send final error, even if error lines were already streamed + send(ws, map[string]string{ + "type": MsgTypeJobError, + "error": fmt.Sprintf("%#v", runErr), + }) + return // Exit, the deferred lock release will still run + } + + send(ws, map[string]string{ + "type": MsgTypeJobFinished, + }) +} + +// decryptData decrypts the Base64-encoded (IV + Ciphertext) string +func decryptData(base64Ciphertext string, base64Key string) (string, error) { + key, err := base64.StdEncoding.DecodeString(base64Key) + if err != nil { + return "", errors.New("failed to decode base64 key") + } + if len(key) != 32 { + return "", errors.New("invalid key length: must be 32 bytes (AES-256)") + } + + data, err := base64.StdEncoding.DecodeString(base64Ciphertext) + if err != nil { + return "", errors.New("failed to decode base64 ciphertext") + } + + // The browser prepends the 12-byte IV to the ciphertext + const ivSize = 12 + if len(data) <= ivSize { + return "", errors.New("invalid ciphertext length") + } + iv := data[:ivSize] + ciphertext := data[ivSize:] + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + plaintext, err := aesgcm.Open(nil, iv, ciphertext, nil) + if err != nil { + // Decryption failed (invalid key or tampered message) + return "", err + } + + return string(plaintext), nil +} + +func calculateGraphDepth(fullPath string) int { + if fullPath == "" { + return 0 + } + return strings.Count(fullPath, "/") +} + +func isVersionOutdated(current, required string) bool { + if required == "" { + return false + } + + // If the CLI is built locally or has a non-semver version like `dev` + // or something, skip the check to not block anyone + currentVer, err := semver.NewVersion(current) + if err != nil { + return false + } + + requiredVer, err := semver.NewVersion(required) + if err != nil { + return false + } + + return currentVer.LessThan(requiredVer) +} diff --git a/sessions/session.go b/sessions/session.go deleted file mode 100644 index 3568851..0000000 --- a/sessions/session.go +++ /dev/null @@ -1,1025 +0,0 @@ -package sessions - -import ( - "bufio" - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/url" - "os" - "os/signal" - "runtime/debug" - "strings" - "sync" - "syscall" - "time" - - "github.com/Masterminds/semver/v3" - "github.com/actionforge/actrun-cli/build" - "github.com/actionforge/actrun-cli/core" - "github.com/actionforge/actrun-cli/utils" - "github.com/gorilla/websocket" -) - -var wsWriteMutex sync.Mutex - -const ( - // Message Types (from browser) - MsgTypeRun = "run" - MsgTypeStop = "stop" - MsgTypeDebugPause = "debug_pause" - MsgTypeDebugResume = "debug_resume" - MsgTypeDebugStep = "debug_step" - MsgTypeDebugAddBreakpoint = "debug_add_breakpoint" - MsgTypeDebugRemoveBreakpoint = "debug_remove_breakpoint" - MsgTypeDebugStepInto = "debug_step_into" - MsgTypeDebugStepOut = "debug_step_out" - - // Message Types (to browser) - MsgTypeLog = "log" - MsgTypeLogError = "log_error" - MsgTypeJobFinished = "job_finished" - MsgTypeJobError = "job_error" - MsgTypeDebugState = "debug_state" - MsgTypeWarning = "warning" - - // Wrapper/Control Message Types (not E2E encrypted) - MsgTypeData = "data" // Wrapper for E2E encrypted payloads - MsgTypeControl = "control" // Server-to-runner control messages - - // Control Message Payloads - ControlBrowserDisconnected = "browser_disconnected" - ControlBrowserConnected = "browser_connected" -) - -func encryptData(plaintext string, base64Key string) (string, error) { - key, err := base64.StdEncoding.DecodeString(base64Key) - if err != nil { - return "", errors.New("failed to decode base64 key") - } - if len(key) != 32 { - return "", errors.New("invalid key length: must be 32 bytes (AES-256)") - } - - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonce := make([]byte, aesgcm.NonceSize()) // NonceSize() is 12 bytes for AES-GCM - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", err - } - - // Encrypt the data (nil prefix means append to nonce) - ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) - - ivAndCiphertext := append(nonce, ciphertext...) - - return base64.StdEncoding.EncodeToString(ivAndCiphertext), nil -} - -func sendEncryptedJSON(ws *websocket.Conn, payload any, sharedKey string) { - jsonPayload, err := json.Marshal(payload) - if err != nil { - utils.LogOut.Errorf("failed to marshal outgoing JSON: %v\n", err) - return - } - - encryptedPayload, err := encryptData(string(jsonPayload), sharedKey) - if err != nil { - utils.LogOut.Errorf("failed to encrypt outgoing message: %v\n", err) - return - } - - msg := EncryptedMessage{ - Type: MsgTypeData, - Payload: encryptedPayload, - } - - wsWriteMutex.Lock() - defer wsWriteMutex.Unlock() - - if err := ws.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { - utils.LogOut.Errorf("failed to set write deadline (connection likely closed): %v\n", err) - return - } - - if err := ws.WriteJSON(msg); err != nil { - utils.LogOut.Errorf("failed to send encrypted message: %v\n", err) - } -} - -// EncryptedMessage is the raw message received from the WebSocket -type EncryptedMessage struct { - Type string `json:"type"` - Payload string `json:"payload"` // Base64-encoded (IV + Ciphertext) -} - -// DecryptedPayload is the structure of the data *after* decryption -type DecryptedPayload struct { - Type string `json:"type"` - Payload string `json:"payload"` // The graph JSON (if type is "run") - Secrets map[string]string `json:"secrets"` - Inputs map[string]any `json:"inputs"` - Env map[string]string `json:"env"` - IgnoreBreakpoints bool `json:"ignore_breakpoints"` - StartPaused bool `json:"start_paused"` - Breakpoints []string `json:"breakpoints"` - RequiredVersion string `json:"required_version"` - NodeID string `json:"nodeId"` -} - -// Global State -var ( - // Use a channel to signal that a graph is currently running - graphRunning = make(chan bool, 1) - // Mutex to protect access to the cancel function - cancelLock sync.Mutex - // Holds the cancel function for the *current* running graph - currentGraphCancel context.CancelFunc -) - -func RunSessionMode(configFile string, graphFileForDebugSession string, sessionToken string, configValueSource string) error { - - if graphFileForDebugSession != "" && sessionToken != "" { - return errors.New("both createDebugSession and sessionToken cannot be set") - } - - if graphFileForDebugSession == "" { - PrintWelcomeMessage() - } - - if configFile != "" { - utils.LogOut.Infof("šŸ‘‰ Configs will be loaded from: %s\n", configFile) - _, err := utils.LoadConfig(configFile) - if err != nil { - return fmt.Errorf("error loading config: %v", err) // fmt.Errorf doesn't strictly need \n if returned as error - } - } else { - utils.LogOut.Info("No config file specified, config values will be derived from environment variables and flags") - } - - apiGatewayUrl := GetGatewayURL() - - wsScheme := "wss" - httpScheme := "https" - if apiGatewayUrl == "localhost" || strings.HasPrefix(apiGatewayUrl, "localhost:") { - wsScheme = "ws" - httpScheme = "http" - } - - var err error - if graphFileForDebugSession != "" { - sessionData, err := StartNewSession(httpScheme, apiGatewayUrl) - if err != nil { - return fmt.Errorf("error creating new debug session: %v", err) - } - sessionToken = sessionData.Token - - utils.LogOut.Infof("šŸ‘‰ Created new debug session for graph file: %s\n", graphFileForDebugSession) - utils.LogOut.Infof("Debug Session: %s\n", fmt.Sprintf("%s//%s/graph#%s", httpScheme, APP_URL, "")) - } else { - sessionToken, err = GetSessionToken(sessionToken, configValueSource) - if err != nil { - return fmt.Errorf("error reading session token: %v", err) - } - } - - if sessionToken == "" { - return fmt.Errorf("no session token provided, exiting.") - } - - // token validation and parsing - packet, err := base64.StdEncoding.DecodeString(sessionToken) - if err != nil { - return fmt.Errorf("invalid token string (not Base64): %v", err) - } - - if len(packet) < 38 { - return errors.New("invalid token (too short).") - } - - expectedChecksum := packet[len(packet)-4:] - dataPayload := packet[:len(packet)-4] - - idLength := int(packet[0]) - if idLength <= 0 || (1+idLength+32) > len(dataPayload) { - return fmt.Errorf("invalid token (malformed structure).") - } - - sessionIDBytes := packet[1 : 1+idLength] - keyBytes := packet[1+idLength : 1+idLength+32] - - dataToHash := append([]byte{}, sessionIDBytes...) - dataToHash = append(dataToHash, keyBytes...) - - hash := sha256.Sum256(dataToHash) - calculatedChecksum := hash[:4] - - if !bytes.Equal(expectedChecksum, calculatedChecksum) { - return fmt.Errorf("āŒ INTEGRITY CHECK FAILED: The token appears to be modified or typo'd.\nCheck the last few characters") - } - - sessionID := string(sessionIDBytes) - sharedKey := base64.StdEncoding.EncodeToString(keyBytes) - - uAddr := url.URL{Scheme: wsScheme, Host: apiGatewayUrl, Path: "/api/v2/ws/runner/" + sessionID} - utils.LogOut.Info("Connecting to Actionforge\n") - - ws, resp, err := websocket.DefaultDialer.Dial(uAddr.String(), nil) - if err != nil { - if resp != nil { - body, readErr := io.ReadAll(resp.Body) - if readErr == nil { - var errMsg map[string]string - if json.Unmarshal(body, &errMsg) == nil && errMsg["error"] != "" { - return fmt.Errorf("🚨 Error: %s", errMsg["error"]) - } - return fmt.Errorf("handshake failed (Status %s): %s", resp.Status, string(body)) - } - return fmt.Errorf("handshake failed: Server returned HTTP status: %s", resp.Status) - } - return fmt.Errorf("failed to connect to %v: %v", apiGatewayUrl, err) - } - defer ws.Close() - - utils.LogOut.Info("Successfully connected to your browser session. Waiting for commands...\n") - - done := make(chan os.Signal, 1) - signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) - - // if browser disconnects during a --create-debug-session run, we switch to detached mode - // to ensure the graph finishes execution instead of hanging on a breakpoint. - var detachMu sync.Mutex - var detachedMode bool - - // current debug op state - var currentDebugOps struct { - sync.Mutex - pause func() - resume func() - step func() - stepInto func() - stepOut func() - addBreakpoint func(string) - removeBreakpoint func(string) - cachedState any - } - - triggerGraphExecution := func( - graphPayload string, - secrets map[string]string, - inputs map[string]any, - env map[string]string, - breakpoints []string, - startPaused bool, - ignoreBreakpoints bool, - ) { - select { - case graphRunning <- true: - ctx, cancel := context.WithCancel(context.Background()) - - cancelLock.Lock() - currentGraphCancel = cancel - cancelLock.Unlock() - - var debugMu sync.Mutex - debugCond := sync.NewCond(&debugMu) - - var bpMutex sync.RWMutex - activeBreakpoints := make(map[string]bool) - - type StepMode int - const ( - StepRun StepMode = iota - StepOver - StepInto - StepOut - ) - - currentStepMode := StepRun - stepReferenceDepth := 0 - - if len(breakpoints) > 0 { - bpMutex.Lock() - for _, bp := range breakpoints { - activeBreakpoints[bp] = true - } - bpMutex.Unlock() - } - - isPaused := startPaused - - // Setup control functions - currentDebugOps.Lock() - - currentDebugOps.pause = func() { - debugMu.Lock() - isPaused = true - currentStepMode = StepRun - debugMu.Unlock() - utils.LogOut.Debug("pausing execution...\n") - } - - currentDebugOps.resume = func() { - debugMu.Lock() - isPaused = false - currentStepMode = StepRun - currentDebugOps.Lock() - currentDebugOps.cachedState = nil - currentDebugOps.Unlock() - debugCond.Broadcast() - debugMu.Unlock() - utils.LogOut.Debug("resuming execution...\n") - } - - currentDebugOps.step = func() { - debugMu.Lock() - currentStepMode = StepOver - debugMu.Unlock() - currentDebugOps.Lock() - currentDebugOps.cachedState = nil - currentDebugOps.Unlock() - debugMu.Lock() - debugCond.Signal() - debugMu.Unlock() - utils.LogOut.Debug("stepping Over...\n") - } - - currentDebugOps.stepInto = func() { - debugMu.Lock() - currentStepMode = StepInto - debugMu.Unlock() - currentDebugOps.Lock() - currentDebugOps.cachedState = nil - currentDebugOps.Unlock() - debugMu.Lock() - debugCond.Signal() - debugMu.Unlock() - utils.LogOut.Debug("stepping Into...\n") - } - - currentDebugOps.stepOut = func() { - debugMu.Lock() - currentStepMode = StepOut - debugMu.Unlock() - currentDebugOps.Lock() - currentDebugOps.cachedState = nil - currentDebugOps.Unlock() - debugMu.Lock() - debugCond.Signal() - debugMu.Unlock() - utils.LogOut.Debug("stepping Out...\n") - } - - currentDebugOps.addBreakpoint = func(nodeId string) { - bpMutex.Lock() - activeBreakpoints[nodeId] = true - bpMutex.Unlock() - utils.LogOut.Debugf("breakpoint added at %s\n", nodeId) - } - - currentDebugOps.removeBreakpoint = func(nodeId string) { - bpMutex.Lock() - delete(activeBreakpoints, nodeId) - bpMutex.Unlock() - utils.LogOut.Debugf("breakpoint removed at %s\n", nodeId) - } - currentDebugOps.Unlock() - - lastKnownDepth := 0 - - debugCb := func(ec *core.ExecutionState, nodeVisit core.ContextVisit) { - fullPath := nodeVisit.Node.GetFullPath() - currentDepth := calculateGraphDepth(fullPath) - utils.LogOut.Debugf("visiting %s | Paused: %v\n", fullPath, isPaused) - - bpMutex.RLock() - hasBreakpoint := activeBreakpoints[fullPath] - bpMutex.RUnlock() - - debugMu.Lock() - - if hasBreakpoint { - utils.LogOut.Debugf("hit explicit breakpoint at %s\n", fullPath) - isPaused = true - currentStepMode = StepRun - } else if !isPaused { - switch currentStepMode { - case StepInto: - isPaused = true - currentStepMode = StepRun - case StepOver: - if currentDepth <= stepReferenceDepth { - isPaused = true - currentStepMode = StepRun - } - case StepOut: - if currentDepth < stepReferenceDepth { - isPaused = true - currentStepMode = StepRun - } - } - } - - if isPaused { - lastKnownDepth = currentDepth - } - - // if browser disconnected override pause to ensure the graph finishes - // Its the same behaviour if you detach a debugger in an IDE - detachMu.Lock() - if detachedMode { - isPaused = false - } - detachMu.Unlock() - - if isPaused { - utils.LogOut.Infof("debugging paused at node: %s\n", fullPath) - - var rootEc *core.ExecutionState = ec - for rootEc.ParentExecution != nil { - rootEc = rootEc.ParentExecution - } - - debugState := map[string]any{ - "type": MsgTypeDebugState, - "fullPath": fullPath, - "executionContext": *rootEc, - } - - go sendEncryptedJSON(ws, debugState, sharedKey) - - currentDebugOps.Lock() - currentDebugOps.cachedState = debugState - currentDebugOps.Unlock() - - debugCond.Wait() - - stepReferenceDepth = lastKnownDepth - isPaused = false - } - - debugMu.Unlock() - } - - if ignoreBreakpoints { - activeBreakpoints = make(map[string]bool) - debugCb = nil - } - - go func() { - runGraphFromConn(ctx, graphPayload, core.RunOpts{ - ConfigFile: configFile, - OverrideSecrets: secrets, - OverrideInputs: inputs, - OverrideEnv: env, - Args: []string{}, - }, ws, sharedKey, debugCb) - - // Cleanup - currentDebugOps.Lock() - currentDebugOps.pause = nil - currentDebugOps.resume = nil - currentDebugOps.step = nil - currentDebugOps.stepInto = nil - currentDebugOps.stepOut = nil - currentDebugOps.addBreakpoint = nil - currentDebugOps.removeBreakpoint = nil - currentDebugOps.cachedState = nil - currentDebugOps.Unlock() - - // if this was a one-off debug session (initiated by --create-debug-session), exit the process when graph completes - if graphFileForDebugSession != "" { - done <- syscall.SIGTERM - } - }() - - default: - utils.LogOut.Warn("Cannot run graph: another graph is already in progress.\n") - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobError, - "error": "A graph is already running.", - }, sharedKey) - } - } - - // cli auto start logic - if graphFileForDebugSession != "" { - graphContent, err := os.ReadFile(graphFileForDebugSession) - if err != nil { - return fmt.Errorf("failed to read debug graph file: %v", err) - } - - go func() { - graphContentBase64 := base64.URLEncoding.EncodeToString(graphContent) - - fragmentParams := url.Values{} - fragmentParams.Set("graph", graphContentBase64) - fragmentParams.Set("session_token", sessionToken) - - fragmentString := fragmentParams.Encode() - - utils.LogOut.Infof("šŸ‘‰ Debug Session: %s\n", fmt.Sprintf("%s://%s/graph#%s", httpScheme, APP_URL, fragmentString)) - - // Force StartPaused = true - triggerGraphExecution(string(graphContent), nil, nil, nil, nil, true, false) - }() - } - - // this is the main message loop - go func() { - defer func() { - if r := recover(); r != nil { - utils.LogOut.Errorf("recovered from panic in message loop: %v\n%s", r, debug.Stack()) - } - done <- syscall.SIGTERM - }() - - for { - var rawMsg EncryptedMessage - err := ws.ReadJSON(&rawMsg) - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - utils.LogOut.Debug("server closed connection cleanly.\n") - } else if strings.Contains(err.Error(), "use of closed network connection") { - // TODO: (Seb) check if there is a better way to handle this - // We reach this when the session shuts down and closes the socket - // while this loop is still waiting for a read. We just ignore it as - // its not really a bug - } else { - utils.LogOut.Warnf("WebSocket Error: %v\n", err) - } - break - } - - if rawMsg.Type == MsgTypeControl { - utils.LogOut.Debugf("received control message: %s\n", rawMsg.Payload) - - switch rawMsg.Payload { - case ControlBrowserDisconnected: - utils.LogOut.Debug("browser disconnected (waiting for reconnect...)\n") - - // if browser disconnected override pause to ensure the graph finishes - // its the same behaviour if you detach a debugger in an IDE - if graphFileForDebugSession != "" { - utils.LogOut.Debug("debug session detected: Resuming graph to completion...\n") - detachMu.Lock() - detachedMode = true - detachMu.Unlock() - - currentDebugOps.Lock() - resumeFn := currentDebugOps.resume - currentDebugOps.Unlock() - - if resumeFn != nil { - resumeFn() - } - } - - case ControlBrowserConnected: - utils.LogOut.Debug("browser connected. Checking for active debug state...\n") - currentDebugOps.Lock() - if currentDebugOps.cachedState != nil { - utils.LogOut.Debug("resending execution state to new browser connection...\n") - go sendEncryptedJSON(ws, currentDebugOps.cachedState, sharedKey) - } - currentDebugOps.Unlock() - } - - continue - } - - if rawMsg.Type != MsgTypeData { - utils.LogOut.Warnf("Received non-data message type, ignoring: %v\n", rawMsg.Type) - continue - } - - decryptedJSON, err := decryptData(rawMsg.Payload, sharedKey) - if err != nil { - utils.LogOut.Errorf("dECRYPTION FAILED: %v", err) - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobError, - "error": "Decryption failed. Check your key.", - }, sharedKey) - continue - } - - var payload DecryptedPayload - if err := json.Unmarshal([]byte(decryptedJSON), &payload); err != nil { - utils.LogOut.Warnf("Failed to parse decrypted JSON: %v\n", err) - continue - } - - currentVer := build.Version - if isVersionOutdated(currentVer, payload.RequiredVersion) { - utils.LogOut.Warnf("Runner version %s is older than required %s\n", currentVer, payload.RequiredVersion) - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeWarning, - "message": fmt.Sprintf("WARNING: Runner version %s is older than required %s", currentVer, payload.RequiredVersion), - }, sharedKey) - } - - switch payload.Type { - - case MsgTypeRun: - triggerGraphExecution( - payload.Payload, - payload.Secrets, - payload.Inputs, - payload.Env, - payload.Breakpoints, - payload.StartPaused, - payload.IgnoreBreakpoints, - ) - - case MsgTypeStop: - utils.LogOut.Debug("received stop signal\n") - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeLog, - "message": "Stop signal received. Attempting to cancel...", - }, sharedKey) - - cancelLock.Lock() - if currentGraphCancel != nil { - currentGraphCancel() - } - cancelLock.Unlock() - - currentDebugOps.Lock() - resumeFn := currentDebugOps.resume - currentDebugOps.Unlock() - - if resumeFn != nil { - resumeFn() - } - - case MsgTypeDebugStep: - currentDebugOps.Lock() - stepFn := currentDebugOps.step - currentDebugOps.Unlock() - - if stepFn != nil { - stepFn() - } - - case MsgTypeDebugStepInto: - currentDebugOps.Lock() - stepIntoFn := currentDebugOps.stepInto - currentDebugOps.Unlock() - - if stepIntoFn != nil { - stepIntoFn() - } - - case MsgTypeDebugStepOut: - currentDebugOps.Lock() - stepOutFn := currentDebugOps.stepOut - currentDebugOps.Unlock() - - if stepOutFn != nil { - stepOutFn() - } - - case MsgTypeDebugPause: - currentDebugOps.Lock() - pauseFn := currentDebugOps.pause - currentDebugOps.Unlock() - - if pauseFn != nil { - pauseFn() - } - - case MsgTypeDebugResume: - currentDebugOps.Lock() - resumeFn := currentDebugOps.resume - currentDebugOps.Unlock() - - if resumeFn != nil { - resumeFn() - } - - case MsgTypeDebugAddBreakpoint: - currentDebugOps.Lock() - addBpFn := currentDebugOps.addBreakpoint - currentDebugOps.Unlock() - - if addBpFn != nil { - addBpFn(payload.NodeID) - } - - case MsgTypeDebugRemoveBreakpoint: - currentDebugOps.Lock() - removeBpFn := currentDebugOps.removeBreakpoint - currentDebugOps.Unlock() - - if removeBpFn != nil { - removeBpFn(payload.NodeID) - } - - default: - utils.LogOut.Debugf("unknown command type: %s\n", payload.Type) - } - } - }() - - <-done - utils.LogOut.Debug("shutting down runtime...\n") - - wsWriteMutex.Lock() - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - wsWriteMutex.Unlock() - - return nil -} - -// GetSessionToken waits for the user to paste a token into standard input, -// reads it, trims it, and returns it. -// It returns the token (string) and any error encountered during reading. -func GetSessionToken(sessionToken string, configValueSource string) (string, error) { - fmt.Println() - fmt.Print("šŸ”‘ Enter session token: ") - - if sessionToken != "" { - fmt.Printf("\n\n", configValueSource) - return sessionToken, nil - } - - for { - - scanner := bufio.NewScanner(os.Stdin) - - if scanner.Scan() { - token := strings.TrimSpace(scanner.Text()) - - if token == "" || strings.EqualFold(token, "exit") || strings.EqualFold(token, "quit") { - return "", nil - } - - if len(token) < 16 { - fmt.Print(" Warning: That doesn't look like a valid session token. Please try again or type 'exit' to quit.\n") - fmt.Print("šŸ”‘ Enter session token: ") - continue - } - - return token, nil - } - - if err := scanner.Err(); err != nil { - return "", err - } - - return "", nil - } -} - -func PrintWelcomeMessage() { - welcomeText := `Welcome to your Actionforge Runner - -----------------------[ HOW TO RUN ]---------------------- - -[ šŸš€ OPTION 1: RUN LOCAL ACTION GRAPH ] - Execute a local graph file directly from your terminal. - Example: $ actrun my-graph.act - -[ šŸ”— OPTION 2: CONNECT TO WEB APP ] - Please paste the session token from your browser to connect. - ----------------------------------------------------------- - -šŸ“– Docs: https://docs.actionforge.dev - -` - - // Print the message to standard output. - // We use fmt.Print here instead of Println to avoid adding an extra - // newline at the very end, keeping the cursor right after the prompt. - fmt.Print(welcomeText) -} - -func runGraphFromConn(ctx context.Context, graphData string, opts core.RunOpts, ws *websocket.Conn, sharedKey string, debugCb core.DebugCallback) { - - // *must* release the lock when it's done - defer func() { - <-graphRunning - - // cleanup the cancel function so "stop" can't be called on a finished job - cancelLock.Lock() - currentGraphCancel = nil - cancelLock.Unlock() - }() - - origStdout := os.Stdout - origStderr := os.Stderr - origLogOutput := utils.LogOut.Out // <-- this is logruses original output - - rOut, wOut, errOut := os.Pipe() - if errOut != nil { - utils.LogOut.Debugf("failed to create pipe for stdout/log capture: %v\n", errOut) - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobError, - "error": fmt.Sprintf("Failed to capture stdout/log: %v", errOut), - }, sharedKey) - return - } - - rErr, wErr, errErr := os.Pipe() - if errErr != nil { - wOut.Close() - utils.LogOut.Debugf("failed to create pipe for stderr capture: %v\n", errErr) - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobError, - "error": fmt.Sprintf("Failed to capture stderr: %v", errErr), - }, sharedKey) - return - } - - os.Stdout = wOut - utils.LogOut.SetOutput(wOut) - - os.Stderr = wErr - - startTime := time.Now() - fmt.Printf("šŸš€ Task started...\n") - - var wg sync.WaitGroup - wg.Add(2) - - // for stdout - go func() { - defer wg.Done() - scanner := bufio.NewScanner(rOut) - for scanner.Scan() { - line := scanner.Text() - - if strings.TrimSpace(line) == "" { - continue - } - - // here we write to original console - fmt.Fprintln(origStdout, line) - - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeLog, - "message": fmt.Sprintf("[%s] %s", time.Now().Format("2006-01-02 15:04:05"), line), - }, sharedKey) - } - if err := scanner.Err(); err != nil { - utils.LogOut.Debugf("error reading from stdout/log pipe: %v\n", err) - } - }() - - // for stderr - go func() { - defer wg.Done() - scanner := bufio.NewScanner(rErr) - for scanner.Scan() { - line := scanner.Text() - - if strings.TrimSpace(line) == "" { - continue - } - - // here we write to original console - fmt.Fprintln(origStderr, line) - - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeLogError, - "message": line, - }, sharedKey) - } - if err := scanner.Err(); err != nil { - utils.LogOut.Debugf("error reading from stderr pipe: %v\n", err) - } - }() - - runErr := func() (err error) { - defer core.RecoverHandler(false) - return core.RunGraphFromString(ctx, "browser", graphData, core.RunOpts{ - ConfigFile: opts.ConfigFile, - OverrideSecrets: opts.OverrideSecrets, - OverrideInputs: opts.OverrideInputs, - OverrideEnv: opts.OverrideEnv, - Args: []string{}, - }, debugCb) - }() - - endTime := time.Now() - duration := endTime.Sub(startTime) - durationStr := fmt.Sprintf("%.2fs", duration.Seconds()) - - // we print this *before* closing the pipes, so it still gets captured - if runErr != nil { - fmt.Printf("\nāŒ Job failed. (Total time: %s)\n", durationStr) - } else { - fmt.Printf("\nāœ… Job succeeded. (Total time: %s)\n", durationStr) - } - - wOut.Close() - wErr.Close() - - os.Stdout = origStdout - os.Stderr = origStderr - utils.LogOut.SetOutput(origLogOutput) - - wg.Wait() - - // all output has already been streamed, including the summary line. - // now we just send the final status message. - if runErr != nil { - utils.LogOut.Debugf("graph execution failed: %v\n", runErr) - // send final error, even if error lines were already streamed - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobError, - "error": fmt.Sprintf("%#v", runErr), - }, sharedKey) - return // Exit, the deferred lock release will still run - } - - sendEncryptedJSON(ws, map[string]string{ - "type": MsgTypeJobFinished, - }, sharedKey) -} - -// decryptData decrypts the Base64-encoded (IV + Ciphertext) string -func decryptData(base64Ciphertext string, base64Key string) (string, error) { - key, err := base64.StdEncoding.DecodeString(base64Key) - if err != nil { - return "", errors.New("failed to decode base64 key") - } - if len(key) != 32 { - return "", errors.New("invalid key length: must be 32 bytes (AES-256)") - } - - data, err := base64.StdEncoding.DecodeString(base64Ciphertext) - if err != nil { - return "", errors.New("failed to decode base64 ciphertext") - } - - // The browser prepends the 12-byte IV to the ciphertext - const ivSize = 12 - if len(data) <= ivSize { - return "", errors.New("invalid ciphertext length") - } - iv := data[:ivSize] - ciphertext := data[ivSize:] - - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - plaintext, err := aesgcm.Open(nil, iv, ciphertext, nil) - if err != nil { - // Decryption failed (invalid key or tampered message) - return "", err - } - - return string(plaintext), nil -} - -func calculateGraphDepth(fullPath string) int { - if fullPath == "" { - return 0 - } - return strings.Count(fullPath, "/") -} - -func isVersionOutdated(current, required string) bool { - if required == "" { - return false - } - - // If the CLI is built locally or has a non-semver version like `dev` - // or something, skip the check to not block anyone - currentVer, err := semver.NewVersion(current) - if err != nil { - return false - } - - requiredVer, err := semver.NewVersion(required) - if err != nil { - return false - } - - return currentVer.LessThan(requiredVer) -} diff --git a/tests_e2e/references/reference_app.sh_l10 b/tests_e2e/references/reference_app.sh_l10 index 77872eb..d2dcb6c 100644 --- a/tests_e2e/references/reference_app.sh_l10 +++ b/tests_e2e/references/reference_app.sh_l10 @@ -16,6 +16,7 @@ Flags: --create-debug-session Create a debug session by connecting to the web app --env-file string Absolute path to an env file (.env) to load before execution -h, --help help for actrun + --local Start a local WebSocket server for direct editor connection --session-token string The session token from your browser -v, --version version for actrun diff --git a/tests_e2e/references/reference_app.sh_l11 b/tests_e2e/references/reference_app.sh_l11 index 5c82fb8..d815fb4 100644 --- a/tests_e2e/references/reference_app.sh_l11 +++ b/tests_e2e/references/reference_app.sh_l11 @@ -29,6 +29,7 @@ Welcome to your Actionforge Runner šŸ“– Docs: https:[REDACTED]/docs.actionforge.dev No config file specified, config values will be derived from environment variables and flags + šŸ”‘ Enter session token: invalid token string (not Base64): illegal base64 data at input byte 7 \ No newline at end of file diff --git a/tests_e2e/references/reference_app.sh_l28 b/tests_e2e/references/reference_app.sh_l28 index 352af2b..61b46a9 100644 --- a/tests_e2e/references/reference_app.sh_l28 +++ b/tests_e2e/references/reference_app.sh_l28 @@ -28,4 +28,5 @@ Welcome to your Actionforge Runner šŸ“– Docs: https:[REDACTED]/docs.actionforge.dev No config file specified, config values will be derived from environment variables and flags + šŸ”‘ Enter session token: no session token provided, exiting. \ No newline at end of file diff --git a/tests_e2e/references/reference_app.sh_l8 b/tests_e2e/references/reference_app.sh_l8 index 352af2b..61b46a9 100644 --- a/tests_e2e/references/reference_app.sh_l8 +++ b/tests_e2e/references/reference_app.sh_l8 @@ -28,4 +28,5 @@ Welcome to your Actionforge Runner šŸ“– Docs: https:[REDACTED]/docs.actionforge.dev No config file specified, config values will be derived from environment variables and flags + šŸ”‘ Enter session token: no session token provided, exiting. \ No newline at end of file diff --git a/tests_e2e/references/reference_contexts_env.sh_l26 b/tests_e2e/references/reference_contexts_env.sh_l26 index 993745d..954da9f 100644 --- a/tests_e2e/references/reference_contexts_env.sh_l26 +++ b/tests_e2e/references/reference_contexts_env.sh_l26 @@ -16,6 +16,7 @@ Flags: --create-debug-session Create a debug session by connecting to the web app --env-file string Absolute path to an env file (.env) to load before execution -h, --help help for actrun + --local Start a local WebSocket server for direct editor connection --session-token string The session token from your browser -v, --version version for actrun diff --git a/tests_e2e/references/reference_local_session.sh_l7 b/tests_e2e/references/reference_local_session.sh_l7 new file mode 100644 index 0000000..eda9bf0 --- /dev/null +++ b/tests_e2e/references/reference_local_session.sh_l7 @@ -0,0 +1,20 @@ +Cleaning up +Connecting to WebSocket +DEBUG PAUSED #1 at node: start +DEBUG PAUSED #2 at node: print-1 +DEBUG PAUSED #3 at node: const-str-1 +Job Finished Successfully! +Launching local runner +Log: created temp working directory for debug session: [REDACTED]/actrun-debug-[REDACTED] +Log: debugging paused at node: const-str-1 +Log: debugging paused at node: print-1 +Log: debugging paused at node: start +Log: step one +Log: step three +Log: step two +Log: āœ… Job succeeded. (Total time: ) +Log: šŸš€ Task started... +Runner connected! Sending Graph (Paused) +Sending RESUME command +Sending STEP command +Sending STEP command diff --git a/tests_e2e/scripts/debug_session.py b/tests_e2e/scripts/debug_session.py index 46c852d..bb52b93 100644 --- a/tests_e2e/scripts/debug_session.py +++ b/tests_e2e/scripts/debug_session.py @@ -244,6 +244,7 @@ async def main(): process.terminate() await process.wait() except ProcessLookupError: + # process already exited pass stdout_task.cancel() diff --git a/tests_e2e/scripts/local_session.act b/tests_e2e/scripts/local_session.act new file mode 100644 index 0000000..cf95e57 --- /dev/null +++ b/tests_e2e/scripts/local_session.act @@ -0,0 +1,91 @@ +editor: + version: + created: v1.34.0 +entry: start +type: generic +nodes: + - id: start + type: core/start@v1 + position: + x: 20 + y: 10 + - id: const-str-1 + type: core/const-string@v1 + position: + x: 30 + y: 210 + inputs: + value: step one + - id: const-str-2 + type: core/const-string@v1 + position: + x: 30 + y: 410 + inputs: + value: step two + - id: const-str-3 + type: core/const-string@v1 + position: + x: 30 + y: 610 + inputs: + value: step three + - id: print-1 + type: core/print@v1 + position: + x: 310 + y: 70 + inputs: + values[0]: null + - id: print-2 + type: core/print@v1 + position: + x: 310 + y: 270 + inputs: + values[0]: null + - id: print-3 + type: core/print@v1 + position: + x: 310 + y: 470 + inputs: + values[0]: null +connections: + - src: + node: const-str-1 + port: result + dst: + node: print-1 + port: values[0] + - src: + node: const-str-2 + port: result + dst: + node: print-2 + port: values[0] + - src: + node: const-str-3 + port: result + dst: + node: print-3 + port: values[0] +executions: + - src: + node: start + port: exec + dst: + node: print-1 + port: exec + - src: + node: print-1 + port: exec + dst: + node: print-2 + port: exec + - src: + node: print-2 + port: exec + dst: + node: print-3 + port: exec diff --git a/tests_e2e/scripts/local_session.py b/tests_e2e/scripts/local_session.py new file mode 100644 index 0000000..3cf1536 --- /dev/null +++ b/tests_e2e/scripts/local_session.py @@ -0,0 +1,142 @@ +import asyncio +import json +import os +import re +import websockets + + +ACTRUN_PATH = "actrun" + + +def clean_and_print(text): + if not text: + return + + timestamp_pattern = r'\[?\d{4}[/-]\d{2}[/-]\d{2}\s+\d{2}:\d{2}:\d{2}\]?' + duration_pattern = r'\d+(?:\.\d+)?s' + + text = re.sub(timestamp_pattern, "", text) + text = re.sub(duration_pattern, "", text) + text = re.sub(r'actrun-debug-\d+', 'actrun-debug-[REDACTED]', text) + + # remove empty lines left over from the redaction + lines = [line.strip() for line in text.splitlines() if line.strip()] + + print("\n".join(lines)) + + +async def drain_stream(stream): + """Read and discard stream output to prevent buffer blocking.""" + while True: + line = await stream.readline() + if not line: + break + + +async def main(): + graph_dir = os.environ.get("ACT_GRAPH_FILES_DIR", ".") + graph_path = os.path.join(graph_dir, "local_session.act") + + with open(graph_path, "r") as f: + graph_content = f.read() + + clean_and_print("Launching local runner") + + env = os.environ.copy() + env["ACT_NOCOLOR"] = "true" + env["ACT_LOGLEVEL"] = "warn" + + process = await asyncio.create_subprocess_exec( + ACTRUN_PATH, "--local", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + drain_out = None + drain_err = None + + try: + # Read stdout lines until we find LOCAL_WS_PORT + port = None + while True: + line = await asyncio.wait_for(process.stdout.readline(), timeout=10) + if not line: + clean_and_print("ERROR: Runner exited before printing port") + return + text = line.decode().strip() + match = re.search(r'LOCAL_WS_PORT=(\d+)', text) + if match: + port = int(match.group(1)) + break + + # Drain remaining subprocess output in background + drain_out = asyncio.create_task(drain_stream(process.stdout)) + drain_err = asyncio.create_task(drain_stream(process.stderr)) + + clean_and_print("Connecting to WebSocket") + + pause_count = 0 + + async with websockets.connect(f"ws://127.0.0.1:{port}/ws") as websocket: + async for message in websocket: + msg = json.loads(message) + msg_type = msg.get("type") + + if msg_type == "control": + if msg["message"] == "runner_connected": + clean_and_print("Runner connected! Sending Graph (Paused)") + + run_payload = { + "type": "run", + "payload": graph_content, + "start_paused": True, + "ignore_breakpoints": False, + "breakpoints": [], + } + await websocket.send(json.dumps(run_payload)) + + elif msg_type == "log": + clean_and_print(f"Log: {msg['message']}") + + elif msg_type == "log_error": + clean_and_print(f"LogError: {msg['message']}") + + elif msg_type == "debug_state": + pause_count += 1 + node = msg.get("fullPath", "unknown") + clean_and_print(f"DEBUG PAUSED #{pause_count} at node: {node}") + + await asyncio.sleep(0.2) + + if pause_count < 3: + clean_and_print("Sending STEP command") + await websocket.send(json.dumps({"type": "debug_step"})) + else: + clean_and_print("Sending RESUME command") + await websocket.send(json.dumps({"type": "debug_resume"})) + + elif msg_type == "job_finished": + clean_and_print("Job Finished Successfully!") + break + + elif msg_type == "job_error": + clean_and_print(f"Job Error: {msg.get('error', 'unknown')}") + break + + finally: + clean_and_print("Cleaning up") + try: + process.terminate() + await process.wait() + except ProcessLookupError: + # proxess already exited + pass + if drain_out: + drain_out.cancel() + if drain_err: + drain_err.cancel() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests_e2e/scripts/local_session.sh b/tests_e2e/scripts/local_session.sh new file mode 100644 index 0000000..a9eefa3 --- /dev/null +++ b/tests_e2e/scripts/local_session.sh @@ -0,0 +1,7 @@ +echo "Test Local Session" + +set -o pipefail + +$PYTHON_EXECUTABLE -m pip install websockets + +#! test $PYTHON_EXECUTABLE $ACT_GRAPH_FILES_DIR/local_session.py | sort diff --git a/tests_e2e/tests_e2e.py b/tests_e2e/tests_e2e.py index 424287e..8c76535 100644 --- a/tests_e2e/tests_e2e.py +++ b/tests_e2e/tests_e2e.py @@ -16,6 +16,8 @@ import re import platform import tempfile +import concurrent.futures +import io from pathlib import Path # Setup paths @@ -218,17 +220,19 @@ def run_test_script(root_path: str, script_file: str, working_dir: str): "PATH_SEPARATOR": os.sep }) - subprocess.run( + return subprocess.run( ["bash", to_posix_path(script_file)], shell=IS_WINDOWS, env=env, cwd=working_dir, - stdout=sys.stdout, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - check=True + text=True, + check=False ) def process_and_run_test(root_dir: str, source_script: str, ref_dir: str, cov_dir: str): + output = io.StringIO() temp_script_path = create_temp_script() redact_func = get_redact_function_script() script_name = os.path.basename(source_script) @@ -266,15 +270,13 @@ def process_and_run_test(root_dir: str, source_script: str, ref_dir: str, cov_di if stripped: if stripped.startswith("function"): fname = stripped.split()[1] if len(stripped.split()) > 1 else "unknown" - print(f"ā€¼ļø 'function' keyword is not POSIX compliant. Use '{fname}() {{' instead.") - sys.exit(1) + raise RuntimeError(f"'function' keyword is not POSIX compliant. Use '{fname}() {{' instead.") if not current_func_name and stripped.endswith("() {"): current_func_name = stripped.split()[0] elif stripped == "}": if not current_func_name: - print(f"ā€¼ļø Closing brace without function definition in {source_script}:{lineno}") - sys.exit(1) + raise RuntimeError(f"Closing brace without function definition in {source_script}:{lineno}") current_func_name = None elif not stripped.startswith("#"): # echo line if not inside a function definition @@ -284,13 +286,15 @@ def process_and_run_test(root_dir: str, source_script: str, ref_dir: str, cov_di dest.write(line) if current_func_name: - print(f"ā€¼ļø Function {current_func_name} was never closed.") - sys.exit(1) + raise RuntimeError(f"Function {current_func_name} was never closed.") tmp_cwd = tempfile.mkdtemp(prefix=f"actrun.{script_name}") - print(f"Running script: {source_script} -> {temp_script_path}:\n cwd: {tmp_cwd}\n") - run_test_script(root_dir, temp_script_path, tmp_cwd) + output.write(f"Running script: {source_script} -> {temp_script_path}:\n cwd: {tmp_cwd}\n\n") + result = run_test_script(root_dir, temp_script_path, tmp_cwd) + if result.stdout: + output.write(result.stdout) normalize_stack_trace_lines(ref_dir, script_name) + return output.getvalue(), result.returncode == 0 def compile_binaries(is_github_runner: bool): if is_github_runner: @@ -377,11 +381,42 @@ def main(): # Run Tests if target_test is None: - for script_path in collect_shell_scripts(scripts_dir): - process_and_run_test(base_cwd, script_path, ref_dir, cov_dir) + scripts = collect_shell_scripts(scripts_dir) else: - full_path = os.path.join(scripts_dir, target_test) - process_and_run_test(base_cwd, full_path, ref_dir, cov_dir) + scripts = [os.path.join(scripts_dir, target_test)] + + failed = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + future_to_script = { + executor.submit(process_and_run_test, base_cwd, script, ref_dir, cov_dir): script + for script in scripts + } + for future in concurrent.futures.as_completed(future_to_script): + script = future_to_script[future] + script_name = os.path.basename(script) + try: + test_output, success = future.result() + print(f"\n{'='*60}") + print(f" {script_name}") + print(f"{'='*60}") + print(test_output, end='') + if not success: + print(f"\n{Style.RED}ā€¼ļø FAILED: {script_name}{Style.RESET}") + failed.append(script) + else: + print(f"\n{Style.GREEN}āœ“ PASSED: {script_name}{Style.RESET}") + except Exception as e: + print(f"\n{'='*60}") + print(f" {script_name}") + print(f"{'='*60}") + print(f"{Style.RED}ā€¼ļø {script_name} failed with exception: {e}{Style.RESET}") + failed.append(script) + + if failed: + print(f"\n{Style.RED}ā€¼ļø {len(failed)} test(s) failed:{Style.RESET}") + for f in failed: + print(f" - {os.path.basename(f)}") + sys.exit(1) # check if there are any diffs between generated refs and committed/staged refs. # excludes reference files from other platforms (e.g., _linux files when running on darwin) @@ -399,8 +434,23 @@ def main(): print(f"Running git diff (excluding other platforms): {' '.join(git_cmd)}") res = subprocess.run(git_cmd, text=True, encoding='utf-8', capture_output=True, check=False) + # git diff only checked for changes so far, but we also want to check for untracked files + untracked_cmd = ['git', 'ls-files', '--others', '--exclude-standard', '--', ref_dir] + untracked_res = subprocess.run(untracked_cmd, text=True, encoding='utf-8', capture_output=True, check=False) + + # still filter out reference files from other platforms although they should never really + untracked_files = [f for f in untracked_res.stdout.splitlines()] + print(res.stdout) - if res.stdout: + has_diff = bool(res.stdout) + has_untracked = bool(untracked_files) + + if has_untracked: + print("untracked reference files:") + for f in untracked_files: + print(f" {f}") + + if has_diff or has_untracked: print("ā€¼ļø there are changes in the tests.") sys.exit(1) else: