diff --git a/.github/workflows/dev_module_build.yml b/.github/workflows/dev_module_build.yml index 0816e2953d..3671c91f26 100644 --- a/.github/workflows/dev_module_build.yml +++ b/.github/workflows/dev_module_build.yml @@ -242,6 +242,7 @@ jobs: find . \ -path ./images/cdi-cloner/cloner-startup -prune -o \ -path ./images/dvcr-artifact -prune -o \ + -path ./images/virtualization-dra -prune -o \ -path ./test/performance/shatal -prune -o \ -type f -name '.golangci.yaml' -printf '%h\0' | \ xargs -0 -n1 | sort -u @@ -279,6 +280,33 @@ jobs: exit 1 fi + lint_go-virtualization-dra: + runs-on: ubuntu-22.04 + name: Run go linter virtualization-dra + steps: + - name: Set up Go 1.25 + uses: actions/setup-go@v5 + with: + go-version: "1.25" + + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + + - name: Install golangci-lint + run: | + echo "Installing golangci-lint..." + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.8.0 + echo "$(go env GOPATH)/bin" >> $GITHUB_PATH + echo "golangci-lint v2.8.0 installed successfully!" + + - name: Lint virtualization-dra directory with golangci-lint + shell: bash + working-directory: ./images/virtualization-dra + run: | + set -e + golangci-lint run + lint_yaml: runs-on: ubuntu-latest name: Run yaml linter diff --git a/Taskfile.yaml b/Taskfile.yaml index 6703c2ad6e..5f5f8c6dc7 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -283,19 +283,19 @@ tasks: kubectl -n d8-virtualization port-forward deploy/virt-api 2345:2345 EOF - dlv:virtualization-dra-plugin:build: - desc: "Build image virtualization-dra-plugin with dlv" + dlv:virtualization-dra-usb:build: + desc: "Build image virtualization-dra-usb with dlv" cmds: - - docker build --build-arg BRANCH=$BRANCH -f ./images/virtualization-dra-plugin/debug/dlv.Dockerfile -t "{{ .DLV_IMAGE }}" --platform linux/amd64 . 
+ - docker build --build-arg BRANCH=$BRANCH -f ./images/virtualization-dra-usb/debug/dlv.Dockerfile -t "{{ .DLV_IMAGE }}" --platform linux/amd64 . - dlv:virtualization-dra-plugin:build-push: - desc: "Build and Push image virtualization-dra-plugin with dlv" + dlv:virtualization-dra-usb:build-push: + desc: "Build and Push image virtualization-dra-usb with dlv" cmds: - - task: dlv:virtualization-dra-plugin:build + - task: dlv:virtualization-dra-usb:build - docker push "{{ .DLV_IMAGE }}" - - task: dlv:virtualization-dra-plugin:print + - task: dlv:virtualization-dra-usb:print - dlv:virtualization-dra-plugin:print: + dlv:virtualization-dra-usb:print: desc: "Print commands for debug" env: IMAGE: "{{ .DLV_IMAGE }}" @@ -314,5 +314,5 @@ tasks: } } }' - kubectl -n d8-virtualization port-forward deploy/virtualization-dra 2345:2345 + kubectl -n d8-virtualization port-forward pod/ 2345:2345 EOF diff --git a/images/virt-artifact/werf.inc.yaml b/images/virt-artifact/werf.inc.yaml index 50ef78fb33..438a1f5ba3 100644 --- a/images/virt-artifact/werf.inc.yaml +++ b/images/virt-artifact/werf.inc.yaml @@ -9,6 +9,7 @@ image: {{ .ModuleNamePrefix }}{{ .ImageName }}-src-artifact final: false fromImage: builder/src +fromCacheVersion: "hotplug-38" # TODO: remove this secrets: - id: SOURCE_REPO value: {{ $.SOURCE_REPO }} diff --git a/images/virt-controller/debug/dlv.Dockerfile b/images/virt-controller/debug/dlv.Dockerfile index b9aa5d9ded..a5562eb25f 100644 --- a/images/virt-controller/debug/dlv.Dockerfile +++ b/images/virt-controller/debug/dlv.Dockerfile @@ -4,7 +4,7 @@ RUN go install github.com/go-delve/delve/cmd/dlv@latest ARG BRANCH="v1.6.2-virtualization" ENV VERSION="1.6.2" -ENV GOVERSION="1.23.0" +ENV GOVERSION="1.24.0" # Copy the git commits for rebuilding the image if the branch changes ADD "https://api.github.com/repos/deckhouse/3p-kubevirt/commits/$BRANCH" /.git-commit-hash.tmp @@ -26,7 +26,7 @@ ENV GOOS=linux ENV CGO_ENABLED=0 ENV GOARCH=amd64 -RUN go build -o 
/kubevirt-binaries/virt-controller ./cmd/virt-controller/ +RUN go build -gcflags="all=-N -l" -o /kubevirt-binaries/virt-controller ./cmd/virt-controller/ FROM busybox diff --git a/images/virtualization-artifact/pkg/common/patch/patch.go b/images/virtualization-artifact/pkg/common/patch/patch.go index 67462efb08..1bfadc688d 100644 --- a/images/virtualization-artifact/pkg/common/patch/patch.go +++ b/images/virtualization-artifact/pkg/common/patch/patch.go @@ -19,6 +19,7 @@ package patch import ( "encoding/json" "fmt" + "slices" "strings" ) @@ -78,18 +79,9 @@ func (jp *JSONPatch) Append(patches ...JSONPatchOperation) { } func (jp *JSONPatch) Delete(op, path string) { - var idx int - var found bool - for i, o := range jp.operations { - if o.Op == op && o.Path == path { - idx = i - found = true - break - } - } - if found { - jp.operations = append(jp.operations[:idx], jp.operations[idx+1:]...) - } + jp.operations = slices.DeleteFunc(jp.operations, func(o JSONPatchOperation) bool { + return o.Op == op && o.Path == path + }) } func (jp *JSONPatch) Len() int { diff --git a/images/virtualization-dra-usb/werf.inc.yaml b/images/virtualization-dra-usb/werf.inc.yaml index 67d793cbed..3d265577a6 100644 --- a/images/virtualization-dra-usb/werf.inc.yaml +++ b/images/virtualization-dra-usb/werf.inc.yaml @@ -8,6 +8,10 @@ import: add: /out/virtualization-dra-usb to: /app/virtualization-dra-usb after: install + - image: {{ .ModuleNamePrefix }}virtualization-dra-builder + add: /out/go-usbip + to: /usb/bin/go-usbip + after: install {{- if eq $.DEBUG_COMPONENT "delve/virtualization-dra-usb" }} - image: debugger add: /app/dlv diff --git a/images/virtualization-dra/.golangci.yaml b/images/virtualization-dra/.golangci.yaml index 59d9904f94..087244eed8 100644 --- a/images/virtualization-dra/.golangci.yaml +++ b/images/virtualization-dra/.golangci.yaml @@ -1,102 +1,61 @@ +# https://golangci-lint.run/usage/configuration/ +version: "2" + run: concurrency: 4 timeout: 10m + issues: # Show all 
errors. max-issues-per-linter: 0 max-same-issues: 0 exclude: - "don't use an underscore in package name" + output: sort-results: true -exclude-files: - - "^zz_generated.*" +exclusions: + paths: + - "^zz_generated.*" -linters-settings: - gofumpt: - extra-rules: true - gci: - sections: - - standard - - default - - prefix(github.com/deckhouse/) - goimports: - local-prefixes: github.com/deckhouse/ - errcheck: - exclude-functions: fmt:.*,[rR]ead|[wW]rite|[cC]lose,io:Copy - revive: - rules: - - name: dot-imports - disabled: true - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [funlen, gocognit, lll] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - importas: - # Do not allow unaliased imports of aliased packages. - # Default: false - no-unaliased: true - # Do not allow non-required aliases. 
- # Default: false - no-extra-aliases: false - # List of aliases - # Default: [] - alias: - - pkg: github.com/deckhouse/virtualization/api/core/v1alpha2 - alias: "" - - pkg: github.com/deckhouse/virtualization/api/subresources/v1alpha2 - alias: subv1alpha2 - - pkg: kubevirt.io/api/core/v1 - alias: virtv1 - - pkg: k8s.io/api/core/v1 - alias: corev1 - - pkg: k8s.io/api/authentication/v1 - alias: authnv1 - - pkg: k8s.io/api/storage/v1 - alias: storagev1 - - pkg: k8s.io/api/networking/v1 - alias: netv1 - - pkg: k8s.io/api/policy/v1 - alias: policyv1 - - pkg: k8s.io/apimachinery/pkg/apis/meta/v1 - alias: metav1 - - pkg: k8s.io/api/resource/v1 - alias: resourcev1 +formatters: + enable: + - gci + - gofmt + - gofumpt + - goimports + settings: + gci: + sections: + - standard + - default + - prefix(github.com/deckhouse/) + gofumpt: + extra-rules: true + goimports: + local-prefixes: github.com/deckhouse/ linters: - disable-all: true + default: none enable: - asciicheck # checks that your code does not contain non-ASCII identifiers - bidichk # checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # [maby too many false positives] checks the function whether use a non-inherited context + - contextcheck # [maybe too many false positives] checks the function whether use a non-inherited context - dogsled # checks assignments with too many blank identifiers (e.g. 
x, _, _, _, := f()) - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 - copyloopvar # detects places where loop variables are copied (Go 1.22+) - - gci # controls golang package import order and makes it always deterministic - gocritic # provides diagnostics that check for bugs, performance and style issues - - gofmt # [replaced by goimports] checks whether code was gofmt-ed - - gofumpt # [replaced by goimports, gofumports is not available yet] checks whether code was gofumpt-ed - - goimports # in addition to fixing imports, goimports also formats your code in the same style as gofmt - - gosimple # specializes in simplifying a code - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - ineffassign # detects when assignments to existing variables are not used - misspell # finds commonly misspelled English words in comments - nolintlint # reports ill-formed or insufficient nolint directives - - reassign # Checks that package variables are not reassigned. 
+ - reassign # checks that package variables are not reassigned - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint - - stylecheck # is a replacement for golint - staticcheck # is a go vet on steroids, applying a ton of static analysis checks - - typecheck # like the front-end of a Go compiler, parses and type-checks Go code - testifylint # checks usage of github.com/stretchr/testify - unconvert # removes unnecessary type conversions - unparam # reports unused function parameters @@ -106,5 +65,63 @@ linters: - thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes - whitespace # detects leading and trailing whitespace - - wastedassign # Finds wasted assignment statements. + - wastedassign # finds wasted assignment statements - importas # checks import aliases against the configured convention + settings: + errcheck: + exclude-functions: + - "(*os.File).Close" + - "(*net.TCPConn).Close" + # - "fmt:.*" + # - "[rR]ead" + # - "[wW]rite" + # - "[cC]lose" + # - "io:Copy" + revive: + rules: + - name: dot-imports + disabled: true + - name: exported + disabled: true + - name: package-comments + disabled: true + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [funlen, gocognit, lll] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + importas: + # Do not allow unaliased imports of aliased packages. + # Default: false + no-unaliased: true + # Do not allow non-required aliases. 
+ # Default: false + no-extra-aliases: false + # List of aliases + # Default: [] + alias: + - pkg: github.com/deckhouse/virtualization/api/core/v1alpha2 + alias: "" + - pkg: github.com/deckhouse/virtualization/api/subresources/v1alpha2 + alias: subv1alpha2 + - pkg: kubevirt.io/api/core/v1 + alias: virtv1 + - pkg: k8s.io/api/core/v1 + alias: corev1 + - pkg: k8s.io/api/authentication/v1 + alias: authnv1 + - pkg: k8s.io/api/storage/v1 + alias: storagev1 + - pkg: k8s.io/api/networking/v1 + alias: netv1 + - pkg: k8s.io/api/policy/v1 + alias: policyv1 + - pkg: k8s.io/apimachinery/pkg/apis/meta/v1 + alias: metav1 + - pkg: k8s.io/api/resource/v1 + alias: resourcev1 diff --git a/images/virtualization-dra/Taskfile.yaml b/images/virtualization-dra/Taskfile.yaml index a28e5823b4..a66fa8427e 100644 --- a/images/virtualization-dra/Taskfile.yaml +++ b/images/virtualization-dra/Taskfile.yaml @@ -52,8 +52,27 @@ tasks: lint:go: desc: "Run golangci-lint" - deps: - - _ensure:golangci-lint cmds: - | golangci-lint run + + lint:go:fix: + desc: "Run golangci-lint with --fix" + cmds: + - | + golangci-lint run --fix + + build:go-usbip: + desc: "Build go-usbip binary" + cmds: + - go build -o bin/go-usbip cmd/usb/go-usbip/main.go + + build:usb-monitor: + desc: "Build usb-monitor binary" + cmds: + - go build -o bin/usb-monitor cmd/usb/usb-monitor/main.go + + api:generate: + desc: "Generate API code" + cmds: + - hack/update-codegen.sh diff --git a/images/virtualization-dra/cmd/usb/dra/app/app.go b/images/virtualization-dra/cmd/usb/dra/app/app.go index ad141b23e5..11cd55cbd5 100644 --- a/images/virtualization-dra/cmd/usb/dra/app/app.go +++ b/images/virtualization-dra/cmd/usb/dra/app/app.go @@ -18,25 +18,33 @@ package app import ( "fmt" + "log/slog" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" + "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/clientcmd" "k8s.io/component-base/cli/flag" "github.com/deckhouse/virtualization-dra/internal/cdi" + 
"github.com/deckhouse/virtualization-dra/internal/featuregates" "github.com/deckhouse/virtualization-dra/internal/plugin" "github.com/deckhouse/virtualization-dra/internal/usb" + usbgateway "github.com/deckhouse/virtualization-dra/internal/usb-gateway" + "github.com/deckhouse/virtualization-dra/internal/usb-gateway/informer" "github.com/deckhouse/virtualization-dra/pkg/cli" + "github.com/deckhouse/virtualization-dra/pkg/controller" "github.com/deckhouse/virtualization-dra/pkg/libusb" "github.com/deckhouse/virtualization-dra/pkg/logger" + "github.com/deckhouse/virtualization-dra/pkg/usbip" ) func NewVirtualizationDraUSBCommand() *cobra.Command { o := &draOptions{ - logging: &logger.Options{}, - monitor: libusb.NewDefaultMonitorConfig(), + logging: &logger.Options{}, + monitor: libusb.NewDefaultMonitorConfig(), + usbipdConfig: &usbip.USBIPDConfig{}, } cmd := &cobra.Command{ @@ -56,24 +64,38 @@ func NewVirtualizationDraUSBCommand() *cobra.Command { fs.AddFlagSet(f) } + cmd.AddCommand(NewInitCommand()) + return cmd } type draOptions struct { - DriverName string - Kubeconfig string - Namespace string - NodeName string - CDIRoot string - HealthzPort int - - logging *logger.Options - monitor *libusb.MonitorConfig + DriverName string + Kubeconfig string + Namespace string + NodeName string + USBGatewaySecretName string + CDIRoot string + HealthzPort int + + logging *logger.Options + monitor *libusb.MonitorConfig + usbipdConfig *usbip.USBIPDConfig + + usbGatewayEnabled bool } func (o *draOptions) Complete() { log := o.logging.Complete() logger.SetDefaultLogger(log) + + o.usbGatewayEnabled = featuregates.Default().USBGatewayEnabled() + if o.usbGatewayEnabled { + if !o.usbipdConfig.ExportEnabled { + slog.Warn("USB gateway is enabled but USBIPD export is disabled. 
Enabling USBIPD export.") + } + o.usbipdConfig.ExportEnabled = true + } } func (o *draOptions) NamedFlags() (fs flag.NamedFlagSets) { @@ -82,12 +104,15 @@ func (o *draOptions) NamedFlags() (fs flag.NamedFlagSets) { mfs.StringVar(&o.Kubeconfig, "kubeconfig", cli.GetStringEnv("KUBECONFIG", ""), "Path to kubeconfig file") mfs.StringVar(&o.Namespace, "namespace", cli.GetStringEnv("NAMESPACE", ""), "Namespace") mfs.StringVar(&o.NodeName, "node-name", cli.GetStringEnv("NODE_NAME", ""), "Node name") + mfs.StringVar(&o.USBGatewaySecretName, "usb-gateway-secret-name", cli.GetStringEnv("USB_GATEWAY_SECRET_NAME", "virtualization-dra-usb-gateway"), "USB gateway secret name") mfs.StringVar(&o.CDIRoot, "cdi-root", cli.GetStringEnv("CDI_ROOT", cdi.SpecDir), "CDI root") mfs.IntVar(&o.HealthzPort, "healthz-port", cli.GetIntEnv("HEALTHZ_PORT", 51515), "Healthz port") o.logging.AddFlags(fs.FlagSet("logging")) o.monitor.AddFlags(fs.FlagSet("usb-monitor")) + o.usbipdConfig.AddFlags(fs.FlagSet("usbipd")) plugin.AddFlags(fs.FlagSet("plugin")) + featuregates.AddFlags(fs.FlagSet("feature-gates")) return fs } @@ -103,27 +128,38 @@ func (o *draOptions) Validate() error { return fmt.Errorf("cdiRoot is required") } + if o.usbGatewayEnabled { + if o.USBGatewaySecretName == "" { + return fmt.Errorf("USBGatewaySecretName is required") + } + } + return nil } -func (o *draOptions) Client() (kubernetes.Interface, error) { +func (o *draOptions) Clients() (kubernetes.Interface, dynamic.Interface, error) { cfg, err := clientcmd.BuildConfigFromFlags("", o.Kubeconfig) if err != nil { - return nil, fmt.Errorf("failed to get rest config: %w", err) + return nil, nil, fmt.Errorf("failed to get rest config: %w", err) } client, err := kubernetes.NewForConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + return nil, nil, fmt.Errorf("failed to create kubernetes client: %w", err) } - return client, nil + dynamicClient, err := dynamic.NewForConfig(cfg) + if err != 
nil { + return nil, nil, fmt.Errorf("failed to create dynamic client: %w", err) + } + + return client, dynamicClient, nil } func (o *draOptions) Run(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - client, err := o.Client() + client, dynamicClient, err := o.Clients() if err != nil { return err } @@ -133,19 +169,73 @@ func (o *draOptions) Run(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create USB monitor: %w", err) } + var usbGateway usbgateway.USBGateway + group, ctx := errgroup.WithContext(ctx) + if o.usbGatewayEnabled { + usbipd, err := o.usbipdConfig.Complete(monitor) + if err != nil { + return fmt.Errorf("failed to create USBIPD: %w", err) + } + + f := informer.NewFactory(client, nil) + secretInformer := f.NamespacedSecret(o.Namespace) + resourceSliceInformer := f.ResourceSlice() + + group.Go(func() error { + return f.Run(ctx) + }) + f.WaitForCacheSync(ctx.Done()) + + usbGatewayController, err := usbgateway.NewUSBGatewayController( + ctx, + o.USBGatewaySecretName, + o.Namespace, + o.NodeName, + o.usbipdConfig.Address, + o.usbipdConfig.Port, + client, + secretInformer, + resourceSliceInformer, + usbip.New(), + ) + if err != nil { + return fmt.Errorf("failed to create USB gateway controller: %w", err) + } + + group.Go(func() error { + return usbipd.Run(ctx) + }) + + group.Go(func() error { + return controller.Run(usbGatewayController, ctx, 1) + }) + + marker := usbgateway.NewMarker(dynamicClient, o.NodeName) + if err = marker.Mark(ctx); err != nil { + return err + } + defer func() { + if err = marker.Unmark(ctx); err != nil { + slog.Error("failed to unmark node for USB gateway", slog.Any("error", err)) + } + }() + + usbGateway = usbGatewayController + } + usbCDIManager, err := cdi.NewManager(o.CDIRoot, "usb", o.DriverName, o.NodeName, "DRA_USB") if err != nil { return fmt.Errorf("failed to create CDI manager: %w", err) } - usbStore, err := usb.NewAllocationStore(ctx, o.NodeName, usbCDIManager, monitor) + usbStore, err := 
usb.NewAllocationStore(ctx, o.NodeName, usbCDIManager, monitor, usbGateway, client) if err != nil { return fmt.Errorf("failed to create USB store: %w", err) } - mgr, err := plugin.NewManager(o.DriverName, o.NodeName, client, usbStore, o.HealthzPort) + mgr, err := plugin.NewManager(o.DriverName, o.NodeName, client, usbStore, o.HealthzPort, o.usbGatewayEnabled) if err != nil { return fmt.Errorf("failed to create manager: %w", err) } diff --git a/images/virtualization-dra/cmd/usb/dra/app/init.go b/images/virtualization-dra/cmd/usb/dra/app/init.go new file mode 100644 index 0000000000..adecfe60bf --- /dev/null +++ b/images/virtualization-dra/cmd/usb/dra/app/init.go @@ -0,0 +1,89 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "fmt" + "log/slog" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/logger" + "github.com/deckhouse/virtualization-dra/pkg/modprobe" +) + +func NewInitCommand() *cobra.Command { + o := &initOptions{ + logging: &logger.Options{}, + } + + cmd := &cobra.Command{ + Use: "init", + Short: "Init USB gateway", + PreRun: func(_ *cobra.Command, _ []string) { + o.Complete() + }, + RunE: o.Run, + } + + return cmd +} + +type initOptions struct { + logging *logger.Options +} + +func (o *initOptions) Complete() { + log := o.logging.Complete() + logger.SetDefaultLogger(log) +} + +func (o *initOptions) Run(_ *cobra.Command, _ []string) error { + kernelRelease, err := modprobe.KernelRelease() + if err != nil { + return fmt.Errorf("failed to get kernel release: %w", err) + } + + slog.Info("Detected kernel release", slog.String("release", kernelRelease)) + + modules := []string{ + filepath.Join("/lib/modules", kernelRelease, "kernel/drivers/usb/usbip/usbip-core.ko"), + filepath.Join("/lib/modules", kernelRelease, "kernel/drivers/usb/usbip/usbip-host.ko"), + filepath.Join("/lib/modules", kernelRelease, "kernel/drivers/usb/usbip/vhci-hcd.ko"), + } + + zstSupported, err := modprobe.KernelSupportsZst(kernelRelease) + if err != nil { + return fmt.Errorf("failed to check kernel support for zst: %w", err) + } + if zstSupported { + for i := range modules { + modules[i] += ".zst" + } + } + + slog.Info("Loading modules", slog.Any("modules", modules)) + + if err := modprobe.LoadModules(modules...); err != nil { + return fmt.Errorf("failed to load modules: %w", err) + } + + slog.Info("Modules loaded successfully") + + return nil +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/app.go b/images/virtualization-dra/cmd/usb/go-usbip/app/app.go new file mode 100644 index 0000000000..0e8cda1674 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/app.go @@ -0,0 +1,95 @@ +/* +Copyright 2025 Flant 
JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package app + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "sigs.k8s.io/yaml" +) + +const long = ` + _ _ + __ _ ___ _ _ ___| |__ (_)_ __ + / _' |/ _ \ _____| | | / __| '_ \| | '_ \ +| (_| | (_) |_____| |_| \__ \ |_) | | |_) | +\__, | \___/ \__,_|___/_.__/|_| .__/ +|___/ |_| + + go-usbip is a implementation of USBIP server and client. +` + +func NewUSBIPCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "usbip", + Short: "USBIP command line tool", + Long: long, + SilenceUsage: true, + SilenceErrors: true, + } + + cmd.AddCommand( + NewRunCommand(), + NewBindCommand(), + NewUnbindCommand(), + NewAttachCommand(), + NewDetachCommand(), + NewAttachInfoCommand(), + NewBindInfoCommand(), + NewInfoCommand(), + NewExportCommand(), + NewUnExportCommand(), + ) + + printer.AddFlags(cmd.PersistentFlags()) + + return cmd +} + +var printer = &printOptions{} + +type printOptions struct { + output string +} + +func (o *printOptions) AddFlags(fs *pflag.FlagSet) { + fs.StringVarP(&o.output, "output", "o", "json", "Output format") +} + +func (o *printOptions) PrintObject(cmd *cobra.Command, data interface{}) error { + switch o.output { + case "json": + b, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal json: %w", err) + } + cmd.Println(string(b)) + return nil + case "yaml": + b, err := yaml.Marshal(data) + if err != nil { + return 
fmt.Errorf("failed to marshal yaml: %w", err) + } + cmd.Println(string(b)) + return nil + default: + return fmt.Errorf("unsupported format %q. Supported formats: [json, yaml]", o.output) + } +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/attach-info.go b/images/virtualization-dra/cmd/usb/go-usbip/app/attach-info.go new file mode 100644 index 0000000000..00bb41db46 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/attach-info.go @@ -0,0 +1,52 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewAttachInfoCommand() *cobra.Command { + o := &attachInfoOptions{} + cmd := &cobra.Command{ + Use: "attach-info", + Short: "Get attach info", + Example: o.Usage(), + RunE: o.Run, + } + + return cmd +} + +type attachInfoOptions struct{} + +func (o *attachInfoOptions) Usage() string { + return ` # Get attach info + $ go-usbip attach-info +` +} + +func (o *attachInfoOptions) Run(cmd *cobra.Command, _ []string) error { + infos, err := usbip.NewUSBAttacher().GetAttachInfo() + if err != nil { + return err + } + + return printer.PrintObject(cmd, infos) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/attach.go b/images/virtualization-dra/cmd/usb/go-usbip/app/attach.go new file mode 100644 index 0000000000..0967c8ea13 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/attach.go @@ -0,0 +1,60 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewAttachCommand() *cobra.Command { + o := &attachOptions{} + cmd := &cobra.Command{ + Use: "attach [:host:] [:busID:]", + Short: "Attach USB devices to USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(2), + } + + o.AddFlags(cmd.Flags()) + + return cmd +} + +type attachOptions struct { + port int +} + +func (o *attachOptions) Usage() string { + return ` # Attach USB devices to USBIP server + $ go-usbip attach 192.168.1.1 3-1 +` +} + +func (o *attachOptions) AddFlags(fs *pflag.FlagSet) { + fs.IntVar(&o.port, "port", 3240, "Remote port for attaching") +} + +func (o *attachOptions) Run(_ *cobra.Command, args []string) error { + host := args[0] + busID := args[1] + _, err := usbip.NewUSBAttacher().Attach(host, busID, o.port) + return err +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/bind-info.go b/images/virtualization-dra/cmd/usb/go-usbip/app/bind-info.go new file mode 100644 index 0000000000..f8f8cf9439 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/bind-info.go @@ -0,0 +1,52 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewBindInfoCommand() *cobra.Command { + o := &bindInfoOptions{} + cmd := &cobra.Command{ + Use: "bind-info", + Short: "Get bind info", + Example: o.Usage(), + RunE: o.Run, + } + + return cmd +} + +type bindInfoOptions struct{} + +func (o *bindInfoOptions) Usage() string { + return ` # Get bind info + $ go-usbip bind-info +` +} + +func (o *bindInfoOptions) Run(cmd *cobra.Command, _ []string) error { + infos, err := usbip.NewUSBBinder().GetBindInfo() + if err != nil { + return err + } + + return printer.PrintObject(cmd, infos) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/bind.go b/images/virtualization-dra/cmd/usb/go-usbip/app/bind.go new file mode 100644 index 0000000000..5c5495f761 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/bind.go @@ -0,0 +1,49 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewBindCommand() *cobra.Command { + o := &bindOptions{} + cmd := &cobra.Command{ + Use: "bind [:busID:]", + Short: "Bind USB devices to USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(1), + } + + return cmd +} + +type bindOptions struct{} + +func (o *bindOptions) Usage() string { + return ` # Bind USB devices to USBIP server + $ go-usbip bind 3-1 +` +} + +func (o *bindOptions) Run(_ *cobra.Command, args []string) error { + busID := args[0] + return usbip.NewUSBBinder().Bind(busID) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/detach.go b/images/virtualization-dra/cmd/usb/go-usbip/app/detach.go new file mode 100644 index 0000000000..a7f3fd8e31 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/detach.go @@ -0,0 +1,55 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "fmt" + "strconv" + + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewDetachCommand() *cobra.Command { + o := &detachOptions{} + cmd := &cobra.Command{ + Use: "detach [:port:]", + Short: "Detach USB devices from USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(1), + } + + return cmd +} + +type detachOptions struct{} + +func (o *detachOptions) Usage() string { + return ` # Detach USB devices from USBIP server + $ go-usbip detach 0 +` +} + +func (o *detachOptions) Run(_ *cobra.Command, args []string) error { + port, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid port: %w", err) + } + return usbip.NewUSBAttacher().Detach(port) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/export.go b/images/virtualization-dra/cmd/usb/go-usbip/app/export.go new file mode 100644 index 0000000000..4d3c0e08ab --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/export.go @@ -0,0 +1,60 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewExportCommand() *cobra.Command { + o := &exportOptions{} + cmd := &cobra.Command{ + Use: "export [:host:] [:busID:]", + Short: "Export USB device on USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(2), + } + + o.AddFlags(cmd.Flags()) + + return cmd +} + +type exportOptions struct { + port int +} + +func (o *exportOptions) Usage() string { + return ` # Export USB devices on USBIP server + $ go-usbip export 192.168.1.1 3-1 +` +} + +func (o *exportOptions) AddFlags(fs *pflag.FlagSet) { + fs.IntVar(&o.port, "port", 3240, "Remote port for exporting") +} + +func (o *exportOptions) Run(_ *cobra.Command, args []string) error { + host := args[0] + busID := args[1] + + return usbip.NewUSBExporter().Export(host, busID, o.port) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/info.go b/images/virtualization-dra/cmd/usb/go-usbip/app/info.go new file mode 100644 index 0000000000..e10094a235 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/info.go @@ -0,0 +1,65 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "cmp" + "slices" + + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" +) + +func NewInfoCommand() *cobra.Command { + o := &infoOptions{} + cmd := &cobra.Command{ + Use: "info", + Short: "Get info", + Example: o.Usage(), + RunE: o.Run, + } + + return cmd +} + +type infoOptions struct{} + +func (o *infoOptions) Usage() string { + return ` # Get info + $ go-usbip info +` +} + +func (o *infoOptions) Run(cmd *cobra.Command, _ []string) error { + discoverDevices, err := libusb.DiscoverPluggedUSBDevices() + if err != nil { + return err + } + + devices := make([]*libusb.USBDevice, 0, len(discoverDevices)) + + for _, device := range discoverDevices { + devices = append(devices, device) + } + + slices.SortFunc(devices, func(a, b *libusb.USBDevice) int { + return cmp.Compare(a.Path, b.Path) + }) + + return printer.PrintObject(cmd, devices) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/run.go b/images/virtualization-dra/cmd/usb/go-usbip/app/run.go new file mode 100644 index 0000000000..ad64272e1e --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/run.go @@ -0,0 +1,80 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "fmt" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewRunCommand() *cobra.Command { + o := &runOptions{ + usbipdConfig: &usbip.USBIPDConfig{}, + monitor: libusb.NewDefaultMonitorConfig(), + } + cmd := &cobra.Command{ + Use: "run", + Short: "Run USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.NoArgs, + } + + o.AddFlags(cmd.Flags()) + + return cmd +} + +type runOptions struct { + usbipdConfig *usbip.USBIPDConfig + monitor *libusb.MonitorConfig +} + +func (o *runOptions) Usage() string { + return ` # Run USBIP server + $ go-usbip run +` +} + +func (o *runOptions) AddFlags(fs *pflag.FlagSet) { + o.usbipdConfig.AddFlags(fs) + o.monitor.AddFlags(fs) +} + +func (o *runOptions) Run(cmd *cobra.Command, _ []string) error { + monitor, err := o.monitor.Complete(cmd.Context(), nil) + if err != nil { + return fmt.Errorf("failed to create usb monitor: %w", err) + } + + usbipd, err := o.usbipdConfig.Complete(monitor) + if err != nil { + return fmt.Errorf("failed to create usbipd: %w", err) + } + + err = usbipd.Run(cmd.Context()) + if err != nil { + return fmt.Errorf("failed to run usbipd: %w", err) + } + + return nil +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/unbind.go b/images/virtualization-dra/cmd/usb/go-usbip/app/unbind.go new file mode 100644 index 0000000000..e557ebac15 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/unbind.go @@ -0,0 +1,49 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package app + +import ( + "github.com/spf13/cobra" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewUnbindCommand() *cobra.Command { + o := &unbindOptions{} + cmd := &cobra.Command{ + Use: "unbind [:busID:]", + Short: "Unbind USB devices from USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(1), + } + + return cmd +} + +type unbindOptions struct{} + +func (o *unbindOptions) Usage() string { + return ` # Unbind USB devices from USBIP server + $ go-usbip unbind 3-1 +` +} + +func (o *unbindOptions) Run(_ *cobra.Command, args []string) error { + busID := args[0] + return usbip.NewUSBBinder().Unbind(busID) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/app/unexport.go b/images/virtualization-dra/cmd/usb/go-usbip/app/unexport.go new file mode 100644 index 0000000000..4d68b9ec8c --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/app/unexport.go @@ -0,0 +1,60 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package app + +import ( + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +func NewUnExportCommand() *cobra.Command { + o := &unExportOptions{} + cmd := &cobra.Command{ + Use: "unexport [:host:] [:busID:]", + Short: "UnExport USB device on USBIP server", + Example: o.Usage(), + RunE: o.Run, + Args: cobra.ExactArgs(2), + } + + o.AddFlags(cmd.Flags()) + + return cmd +} + +type unExportOptions struct { + port int +} + +func (o *unExportOptions) Usage() string { + return ` # UnExport USB devices on USBIP server + $ go-usbip unexport 192.168.1.1 3-1 +` +} + +func (o *unExportOptions) AddFlags(fs *pflag.FlagSet) { + fs.IntVar(&o.port, "port", 3240, "Remote port for unexporting") +} + +func (o *unExportOptions) Run(_ *cobra.Command, args []string) error { + host := args[0] + busID := args[1] + + return usbip.NewUSBExporter().Unexport(host, busID, o.port) +} diff --git a/images/virtualization-dra/cmd/usb/go-usbip/main.go b/images/virtualization-dra/cmd/usb/go-usbip/main.go new file mode 100644 index 0000000000..ce41bb7c61 --- /dev/null +++ b/images/virtualization-dra/cmd/usb/go-usbip/main.go @@ -0,0 +1,29 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package main + +import ( + "os" + + "github.com/deckhouse/virtualization-dra/cmd/usb/go-usbip/app" + "github.com/deckhouse/virtualization-dra/pkg/cli" +) + +func main() { + code := cli.Main(app.NewUSBIPCommand()) + os.Exit(code) +} diff --git a/images/virtualization-dra/cmd/usb/usb-monitor/main.go b/images/virtualization-dra/cmd/usb/usb-monitor/main.go new file mode 100644 index 0000000000..e7b9564f7f --- /dev/null +++ b/images/virtualization-dra/cmd/usb/usb-monitor/main.go @@ -0,0 +1,106 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package main + +import ( + "encoding/json" + "log/slog" + "os" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/cli" + "github.com/deckhouse/virtualization-dra/pkg/libusb" + "github.com/deckhouse/virtualization-dra/pkg/logger" +) + +func main() { + code := cli.Main(NewUSBMonitorCommand()) + os.Exit(code) +} + +func NewUSBMonitorCommand() *cobra.Command { + o := &options{ + monitor: libusb.NewDefaultMonitorConfig(), + logging: &logger.Options{}, + } + + cmd := &cobra.Command{ + Use: "usb-monitor", + Short: "USB monitor", + SilenceUsage: true, + SilenceErrors: true, + PreRun: func(cmd *cobra.Command, args []string) { + o.Complete() + }, + RunE: o.Run, + } + + o.AddFlags(cmd.Flags()) + + return cmd +} + +type options struct { + monitor *libusb.MonitorConfig + logging *logger.Options +} + +func (o *options) Complete() { + log := o.logging.Complete() + logger.SetDefaultLogger(log) +} + +func (o *options) AddFlags(fs *pflag.FlagSet) { + o.monitor.AddFlags(fs) + o.logging.AddFlags(fs) +} + +func (o *options) Run(cmd *cobra.Command, _ []string) error { + monitor, err := o.monitor.Complete(cmd.Context(), nil) + if err != nil { + return err + } + + devices := monitor.GetDevices() + o.printDevices(cmd, devices) + + changes := monitor.DeviceChanges() + for { + select { + case <-cmd.Context().Done(): + return nil + case _, ok := <-changes: + if !ok { + return nil + } + slog.Info("USB devices changed") + devices = monitor.GetDevices() + o.printDevices(cmd, devices) + } + } +} + +func (o *options) printDevices(cmd *cobra.Command, devices []libusb.USBDevice) { + b, err := json.Marshal(devices) + if err != nil { + slog.Error("failed to marshal devices", slog.Any("err", err)) + return + } + cmd.Println(string(b)) +} diff --git a/images/virtualization-dra/go.mod b/images/virtualization-dra/go.mod index d24c9d9d3f..42600051b4 100644 --- a/images/virtualization-dra/go.mod +++ b/images/virtualization-dra/go.mod @@ -1,6 +1,6 
@@ module github.com/deckhouse/virtualization-dra -go 1.24.7 +go 1.25.0 tool github.com/onsi/ginkgo/v2/ginkgo @@ -8,6 +8,7 @@ require ( github.com/containerd/nri v0.10.0 github.com/deckhouse/deckhouse/pkg/log v0.1.0 github.com/go-logr/logr v1.4.2 + github.com/klauspost/compress v1.18.0 github.com/onsi/ginkgo/v2 v2.21.0 github.com/onsi/gomega v1.35.1 github.com/spf13/cobra v1.10.1 @@ -23,12 +24,16 @@ require ( k8s.io/klog/v2 v2.130.1 k8s.io/kubelet v0.34.2 k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 + sigs.k8s.io/yaml v1.6.0 tags.cncf.io/container-device-interface v1.0.1 tags.cncf.io/container-device-interface/specs-go v1.0.0 ) require ( github.com/DataDog/gostackparse v0.7.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/blang/semver/v4 v4.0.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/ttrpc v1.2.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -56,11 +61,17 @@ require ( github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.22.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect github.com/x448/float16 v0.8.4 // indirect go.etcd.io/etcd/client/pkg/v3 v3.6.4 // indirect + go.opentelemetry.io/otel v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect @@ -81,5 +92,4 @@ require ( sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect sigs.k8s.io/randfill 
v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect - sigs.k8s.io/yaml v1.6.0 // indirect ) diff --git a/images/virtualization-dra/go.sum b/images/virtualization-dra/go.sum index f4166b4773..1d52678be5 100644 --- a/images/virtualization-dra/go.sum +++ b/images/virtualization-dra/go.sum @@ -1,7 +1,11 @@ github.com/DataDog/gostackparse v0.7.0 h1:i7dLkXHvYzHV308hnkvVGDL3BR4FWl7IsXNPz/IGQh4= github.com/DataDog/gostackparse v0.7.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/nri v0.10.0 h1:bt2NzfvlY6OJE0i+fB5WVeGQEycxY7iFVQpEbh7J3Go= @@ -61,6 +65,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/knqyf263/go-plugin v0.9.0 h1:CQs2+lOPIlkZVtcb835ZYDEoyyWJWLbSTWeCs0EwTwI= github.com/knqyf263/go-plugin v0.9.0/go.mod h1:2z5lCO1/pez6qGo8CvCxSlBFSEat4MEp1DrnA+f7w8Q= 
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -70,6 +76,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mndrix/tap-go v0.0.0-20171203230836-629fa407e90b/go.mod h1:pzzDgJWZ34fGzaAZGFW22KVZDfyrYW+QABMrWnJBnSs= @@ -98,6 +106,12 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= 
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= diff --git a/images/virtualization-dra/internal/consts/consts.go b/images/virtualization-dra/internal/consts/consts.go new file mode 100644 index 0000000000..d40a2d15af --- /dev/null +++ b/images/virtualization-dra/internal/consts/consts.go @@ -0,0 +1,31 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package consts + +const ( + USBGatewayLabel = "virtualization.deckhouse.io/usb-gateway" +) + +const ( + VirtualizationDraUSBDriverName = "virtualization-usb" +) + +const ( + AnnUSBDeviceAddresses = "usb.virtualization.deckhouse.io/device-addresses" + AnnUSBDeviceUser = "usb.virtualization.deckhouse.io/device-user" + AnnUSBDeviceGroup = "usb.virtualization.deckhouse.io/device-group" +) diff --git a/images/virtualization-dra/internal/featuregates/featuregates.go b/images/virtualization-dra/internal/featuregates/featuregates.go new file mode 100644 index 0000000000..e61a9f2692 --- /dev/null +++ b/images/virtualization-dra/internal/featuregates/featuregates.go @@ -0,0 +1,85 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package featuregates + +import ( + "github.com/spf13/pflag" + "k8s.io/component-base/featuregate" +) + +const ( + USBGateway featuregate.Feature = "USBGateway" + USBNodeLocalMultiAllocation featuregate.Feature = "USBNodeLocalMultiAllocation" +) + +var featureSpecs = map[featuregate.Feature]featuregate.FeatureSpec{ + USBGateway: { + Default: false, + PreRelease: featuregate.Alpha, + }, + USBNodeLocalMultiAllocation: { + Default: false, + PreRelease: featuregate.Alpha, + }, +} + +var ( + instance *FeatureGate + addFlags func(fs *pflag.FlagSet) +) + +func init() { + gate, gateAddFlags, _, err := New() + if err != nil { + panic(err) + } + instance = gate + addFlags = gateAddFlags +} + +func AddFlags(fs *pflag.FlagSet) { + addFlags(fs) +} + +func Default() *FeatureGate { + return instance +} + +type ( + AddFlagsFunc func(fs *pflag.FlagSet) + SetFromMapFunc func(m map[string]bool) error +) + +func New() (*FeatureGate, AddFlagsFunc, SetFromMapFunc, error) { + gate := featuregate.NewFeatureGate() + if err := gate.Add(featureSpecs); err != nil { + return nil, nil, nil, err + } + return &FeatureGate{gate}, gate.AddFlag, gate.SetFromMap, nil +} + +type FeatureGate struct { + featuregate.FeatureGate +} + +func (f *FeatureGate) USBGatewayEnabled() bool { + return f.Enabled(USBGateway) +} + +func (f *FeatureGate) USBNodeLocalMultiAllocationEnabled() bool { + return f.Enabled(USBNodeLocalMultiAllocation) +} diff --git a/images/virtualization-dra/internal/plugin/driver.go b/images/virtualization-dra/internal/plugin/driver.go index 40686c42b5..2211821f86 100644 --- 
a/images/virtualization-dra/internal/plugin/driver.go +++ b/images/virtualization-dra/internal/plugin/driver.go @@ -27,11 +27,13 @@ import ( utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/kubeletplugin" + "k8s.io/dynamic-resource-allocation/resourceslice" "github.com/deckhouse/deckhouse/pkg/log" + "github.com/deckhouse/virtualization-dra/internal/plugin/wrapresourceslice" ) -func NewDriver(driverName, nodeName string, kubeClient kubernetes.Interface, allocator Allocator) (*Driver, error) { +func NewDriver(driverName, nodeName string, kubeClient kubernetes.Interface, allocator Allocator, shared bool) (*Driver, error) { if driverName == "" { return nil, fmt.Errorf("driver name is required") } @@ -40,11 +42,17 @@ func NewDriver(driverName, nodeName string, kubeClient kubernetes.Interface, all return nil, fmt.Errorf("failed to initialize plugin directory: %w", err) } + poolName := "" + if shared { + poolName = nodeName + } + return &Driver{ driverName: driverName, nodeName: nodeName, kubeClient: kubeClient, allocator: allocator, + poolName: poolName, log: slog.With(slog.String("driver", driverName), slog.String("component", "driver")), }, nil } @@ -54,11 +62,13 @@ func NewDriver(driverName, nodeName string, kubeClient kubernetes.Interface, all type Driver struct { driverName string nodeName string + poolName string kubeClient kubernetes.Interface allocator Allocator log *slog.Logger + publisher resourcePublisher helper *kubeletplugin.Helper pluginCtx context.Context pluginCancel context.CancelCauseFunc @@ -69,6 +79,8 @@ func (d *Driver) Start(ctx context.Context) error { d.pluginCtx = ctx d.pluginCancel = cancel + d.publisher = newCustomPublisher(ctx, d.driverName, d.nodeName, d.poolName, d.kubeClient, d.HandleError) + log.Info("Starting dra plugin") helper, err := kubeletplugin.Start( ctx, @@ -97,6 +109,7 @@ func (d *Driver) Wait() { } func (d *Driver) Shutdown() { + d.publisher.Stop() if d.helper 
!= nil { d.log.Info("Stopping dra plugin") d.helper.Stop() @@ -128,6 +141,7 @@ func (d *Driver) prepareResourceClaim(ctx context.Context, claim *resourcev1.Res preparedPBs, err := d.allocator.Prepare(ctx, claim) if err != nil { + d.log.Error("Error preparing devices for claim", slog.Any("error", err), slog.String("uid", string(claim.UID))) return kubeletplugin.PrepareResult{ Err: fmt.Errorf("error preparing devices for claim %v: %w", claim.UID, err), } @@ -186,7 +200,7 @@ func (d *Driver) startPublisher(ctx context.Context) { return case resources := <-ch: d.log.Info("Publishing devices", slog.Any("resources", resources)) - err := d.helper.PublishResources(ctx, resources) + err := d.publisher.PublishResources(ctx, toWrappedDriverResources(resources)) if err != nil { d.log.Error("Failed to publish devices", slog.Any("err", err)) } @@ -194,3 +208,31 @@ func (d *Driver) startPublisher(ctx context.Context) { } }() } + +func toWrappedDriverResources(resources resourceslice.DriverResources) wrapresourceslice.DriverResources { + wrapped := wrapresourceslice.DriverResources{ + Pools: make(map[string]wrapresourceslice.Pool, len(resources.Pools)), + } + + for name, pool := range resources.Pools { + wrapped.Pools[name] = wrapresourceslice.Pool{ + NodeSelector: pool.NodeSelector, + Generation: pool.Generation, + Slices: toWrappedSlices(pool.Slices), + } + } + + return wrapped +} + +func toWrappedSlices(slices []resourceslice.Slice) []wrapresourceslice.Slice { + wrapped := make([]wrapresourceslice.Slice, len(slices)) + for i, slice := range slices { + wrapped[i] = wrapresourceslice.Slice{ + Devices: slice.Devices, + SharedCounters: slice.SharedCounters, + PerDeviceNodeSelection: slice.PerDeviceNodeSelection, + } + } + return wrapped +} diff --git a/images/virtualization-dra/internal/plugin/manager.go b/images/virtualization-dra/internal/plugin/manager.go index 409da1be3d..563a329f7d 100644 --- a/images/virtualization-dra/internal/plugin/manager.go +++ 
b/images/virtualization-dra/internal/plugin/manager.go @@ -28,10 +28,10 @@ type Manager struct { checker *HealthCheck } -func NewManager(driverName, nodeName string, kubeClient kubernetes.Interface, allocator Allocator, healthPort int) (*Manager, error) { +func NewManager(driverName, nodeName string, kubeClient kubernetes.Interface, allocator Allocator, healthPort int, shared bool) (*Manager, error) { m := &Manager{} - driver, err := NewDriver(driverName, nodeName, kubeClient, allocator) + driver, err := NewDriver(driverName, nodeName, kubeClient, allocator, shared) if err != nil { return nil, err } diff --git a/images/virtualization-dra/internal/plugin/publish.go b/images/virtualization-dra/internal/plugin/publish.go new file mode 100644 index 0000000000..4b53c70b21 --- /dev/null +++ b/images/virtualization-dra/internal/plugin/publish.go @@ -0,0 +1,134 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package plugin + +import ( + "context" + "errors" + "fmt" + "sync" + + "k8s.io/client-go/kubernetes" + "k8s.io/dynamic-resource-allocation/kubeletplugin" + "k8s.io/klog/v2" + + "github.com/deckhouse/virtualization-dra/internal/plugin/wrapresourceslice" +) + +type resourcePublisher interface { + PublishResources(ctx context.Context, resources wrapresourceslice.DriverResources) error + Stop() +} +type errorHandler func(ctx context.Context, err error, msg string) + +func newCustomPublisher(ctx context.Context, driverName, nodeName, poolName string, kubeClient kubernetes.Interface, errorHandler errorHandler) resourcePublisher { + ctx, cancel := context.WithCancelCause(ctx) + return &customPublisher{ + driverName: driverName, + nodeName: nodeName, + poolName: poolName, + kubeClient: kubeClient, + errorHandler: errorHandler, + backgroundCtx: ctx, + cancel: cancel, + } +} + +type customPublisher struct { + driverName string + nodeName string + poolName string + kubeClient kubernetes.Interface + backgroundCtx context.Context + cancel func(cause error) + errorHandler errorHandler + + mutex sync.Mutex + resourceSliceController *wrapresourceslice.Controller +} + +func (p *customPublisher) PublishResources(_ context.Context, resources wrapresourceslice.DriverResources) error { + if p.nodeName == "" { + return errors.New("no NodeName was set to publish resources") + } + + p.mutex.Lock() + defer p.mutex.Unlock() + + owner := wrapresourceslice.Owner{ + APIVersion: "v1", + Kind: "Node", + Name: p.nodeName, + } + driverResources := &wrapresourceslice.DriverResources{ + Pools: resources.Pools, + } + + if p.resourceSliceController == nil { + // Start publishing the information. The controller is using + // our background context, not the one passed into this + // function, and thus is connected to the lifecycle of the + // plugin. 
+ controllerCtx := p.backgroundCtx + //nolint:contextcheck // copied from dra helper + controllerLogger := klog.FromContext(controllerCtx) + controllerLogger = klog.LoggerWithName(controllerLogger, "ResourceSlice controller") + controllerCtx = klog.NewContext(controllerCtx, controllerLogger) + var err error + //nolint:contextcheck // copied from dra helper + if p.resourceSliceController, err = wrapresourceslice.StartController(controllerCtx, + wrapresourceslice.Options{ + DriverName: p.driverName, + KubeClient: p.kubeClient, + Owner: &owner, + Resources: driverResources, + ErrorHandler: func(ctx context.Context, err error, msg string) { + // ResourceSlice publishing errors like dropped fields or + // invalid spec are not going to get resolved by retrying, + // but neither is restarting the process going to help + // -> all errors are recoverable. + p.errorHandler(ctx, recoverableError{error: err}, msg) + }, + ReconcileOnlyPoolName: p.poolName, + }); err != nil { + return fmt.Errorf("start ResourceSlice controller: %w", err) + } + return nil + } + // Inform running controller about new information. 
+ if err := p.resourceSliceController.Update(driverResources); err != nil { + return fmt.Errorf("failed to update ResourceSlice controller: %w", err) + } + + return nil +} + +func (p *customPublisher) Stop() { + if p == nil { + return + } + p.cancel(errors.New("customPublisher was stopped")) +} + +type recoverableError struct { + error +} + +var _ error = recoverableError{} + +func (err recoverableError) Is(other error) bool { return other == kubeletplugin.ErrRecoverable } +func (err recoverableError) Unwrap() error { return err.error } diff --git a/images/virtualization-dra/internal/plugin/wrapresourceslice/resourceslicecontroller.go b/images/virtualization-dra/internal/plugin/wrapresourceslice/resourceslicecontroller.go new file mode 100644 index 0000000000..9f42c0f6dc --- /dev/null +++ b/images/virtualization-dra/internal/plugin/wrapresourceslice/resourceslicecontroller.go @@ -0,0 +1,1096 @@ +/* +Copyright 2024 The Kubernetes Authors. +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Initially copied from https://github.com/kubernetes/dynamic-resource-allocation/blob/v0.35.1/resourceslice/resourceslicecontroller.go#L177 +*/ + +// TODO: https://github.com/kubernetes/kubernetes/issues/137011 +// Delete this file once the issue is fixed. 
+ +package wrapresourceslice + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + corev1 "k8s.io/api/core/v1" + resourcev1 "k8s.io/api/resource/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/conversion" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/diff" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/sets" + watch "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + cgocore "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" + draclient "k8s.io/dynamic-resource-allocation/client" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" +) + +const ( + // poolNameIndex is the name for the ResourceSlice store's index function, + // which is to index by ResourceSlice.Spec.Pool.Name + poolNameIndex = "poolName" + + // Including adds in the mutation cache is not safe: We could add a slice, store it, + // and then the slice gets deleted without the informer hearing anything about that. + // Then the obsolete slice remains in the mutation cache. + // + // To mitigate this, we use a TTL and check a pool again once added slices expire. + DefaultMutationCacheTTL = time.Minute + + // DefaultSyncDelay defines how long to wait between receiving the most recent + // informer event and syncing again. This is long enough that the informer cache + // should be up-to-date (matters mostly for deletes because an out-dated cache + // causes redundant delete API calls) and not too long that a human mistake + // doesn't get fixed while that human is waiting for it. + DefaultSyncDelay = 30 * time.Second +) + +// Controller synchronizes information about resources of one driver with +// ResourceSlice objects. 
// Controller synchronizes information about resources of one driver with
// ResourceSlice objects. It supports node-local and network-attached
// resources. A DRA driver for node-local resources typically runs this
// controller as part of its kubelet plugin.
type Controller struct {
	// cancel stops all background activity; Stop calls it with a cause.
	cancel         func(cause error)
	driverName     string
	owner          *Owner
	resourceClient *draclient.Client
	coreClient     cgocore.CoreV1Interface
	// wg tracks the worker and informer goroutines; Stop waits on it.
	wg sync.WaitGroup
	// The queue is keyed with the pool name that needs work.
	queue workqueue.TypedRateLimitingInterface[string]
	// sliceStore is a mutation cache layered on top of the informer store,
	// indexed by pool name.
	sliceStore       cache.MutationCache
	mutationCacheTTL time.Duration
	syncDelay        time.Duration
	// errorHandler is invoked whenever syncing a pool fails or the
	// apiserver dropped fields from a stored slice.
	errorHandler func(ctx context.Context, err error, msg string)

	// Last time that a ResourceSlice of a pool was created.
	// At that time + cache mutation TTL do we have to sync again
	// because the locally cached slice might have stayed in the
	// cache erroneously (not removed on delete by someone else)
	// and we have to check again once it has been removed from
	// the cache.
	//
	// It's not sufficient to schedule a delayed sync because
	// another sync scheduled by an event overwrites the older
	// one, so we would sync too soon and then not again.
	//
	// The key is the pool name. This makes each time entry
	// unique for syncPool calls for the pool.
	lastAddByPool map[string]time.Time

	// Must use atomic access...
	numCreates int64
	numUpdates int64
	numDeletes int64

	// mutex protects the resources pointer.
	mutex sync.RWMutex

	// When receiving updates from the driver, the entire pointer is replaced,
	// so it is okay to not do a deep copy of it when reading it. Only reading
	// the pointer itself must be protected by a read lock.
	resources *DriverResources

	// reconcileOnlyPoolName, if non-empty, restricts the informer to
	// ResourceSlices whose pool name equals this value. In that mode the
	// node name is neither used as a list/watch field selector nor set on
	// newly created slices, even if the owner is a Node.
	reconcileOnlyPoolName string
}

// +k8s:deepcopy-gen=true

// DriverResources is a complete description of all resources synchronized by the controller.
type DriverResources struct {
	// Each driver may manage different resource pools.
	// The map key is the pool name.
	Pools map[string]Pool
}
// +k8s:deepcopy-gen=true

// Pool is the collection of devices belonging to the same pool.
type Pool struct {
	// NodeSelector may be different for each pool. Must not get set together
	// with Resources.NodeName. If nil and Resources.NodeName is not set,
	// then devices are available on all nodes.
	NodeSelector *corev1.NodeSelector

	// Generation can be left at zero. It gets bumped up automatically
	// by the controller.
	Generation int64

	// Slices is a list of all ResourceSlices that the driver
	// wants to publish for this pool. The driver must ensure
	// that each resulting slice is valid. See the API
	// definition for details, in particular the limit on
	// the number of devices.
	//
	// If slices are not valid, then the controller will
	// log errors produced by the apiserver.
	//
	// Drivers should publish at least one slice for each
	// pool that they normally manage, even if that slice
	// is empty. "Empty pool" is different from "no pool"
	// because it shows that the driver is up-and-running
	// and simply doesn't have any devices.
	Slices []Slice
}

// +k8s:deepcopy-gen=true

// Slice is turned into one ResourceSlice by the controller.
type Slice struct {
	// Devices lists all devices which are part of the slice.
	Devices []resourcev1.Device
	// SharedCounters is copied verbatim into the SharedCounters field of
	// the published ResourceSlice spec.
	SharedCounters []resourcev1.CounterSet
	// PerDeviceNodeSelection is copied verbatim into the
	// PerDeviceNodeSelection field of the published ResourceSlice spec.
	PerDeviceNodeSelection *bool
}

// +k8s:deepcopy-gen=true

// Owner is the resource which is meant to be listed as owner of the resource slices.
// For a node the UID may be left blank. The controller will look it up automatically.
type Owner struct {
	APIVersion string
	Kind       string
	Name       string
	UID        types.UID
}
+func StartController(ctx context.Context, options Options) (*Controller, error) { + logger := klog.FromContext(ctx) + c, err := newController(ctx, options) + if err != nil { + return nil, fmt.Errorf("create controller: %w", err) + } + + logger.V(3).Info("Starting") + c.wg.Add(1) + go func() { + defer c.wg.Done() + defer logger.V(3).Info("Stopping") + c.run(ctx) + }() + return c, nil +} + +// Options contains various optional settings for [StartController]. +type Options struct { + // DriverName is the required name of the DRA driver. + DriverName string + + // KubeClient is used to read Node objects (if necessary) and to access + // ResourceSlices. It must be specified. + KubeClient kubernetes.Interface + + // If the owner is a v1.Node, then the NodeName field in the + // ResourceSlice objects is set and used to identify objects + // managed by the controller. The UID is not needed in that + // case, the controller will determine it automatically. + // + // The owner must be cluster-scoped. This is not always possible, + // therefore it is optional. A driver without a owner must take + // care that remaining slices get deleted manually as part of + // a driver uninstall because garbage collection won't work. + Owner *Owner + + // This is the initial desired set of slices. Nil means "no resources". + Resources *DriverResources + + // Queue can be used to override the default work queue implementation. + Queue workqueue.TypedRateLimitingInterface[string] + + // MutationCacheTTL can be used to change the default TTL of one minute. + // See source code for details. + MutationCacheTTL *time.Duration + + // SyncDelay defines how long to wait between receiving the most recent + // informer event and syncing again. The default is 30 seconds. 
+ // + // This is long enough that the informer cache should be up-to-date + // (matter mostly for deletes because an out-dated cache causes + // redundant delete API calls) and not too long that a human mistake + // doesn't get fixed while that human is waiting for it. + SyncDelay *time.Duration + + // ErrorHandler will get called whenever the controller encounters + // a problem while trying to publish ResourceSlices. The controller + // will retry once the handler returns. What the handler does with + // that information is up to the handler. It could log the error, + // replace the slices if they cannot be published (see below), + // or force the program running the controller to fail by exiting. + // + // If some fields were dropped because the cluster does not support + // the feature they depend on, then the error is or wraps an + // [DroppedFieldsError] instance. Use [errors.As] to convert to that + // type: + // var droppedFields *resourceslice.DroppedFieldsError + // if errors.As(err, &droppedFields) { ... do something with droppedFields ... } + // + // The default is [utilruntime.HandleErrorWithContext] which just logs + // the problem. + ErrorHandler func(ctx context.Context, err error, msg string) + + ReconcileOnlyPoolName string +} + +// DroppedFieldsError is reported through the ErrorHandler in [Options] if +// a slice could not be published exactly as desired by the driver. +type DroppedFieldsError struct { + PoolName string + SliceIndex int + DesiredSlice, ActualSlice *resourcev1.ResourceSlice +} + +func (err *DroppedFieldsError) Error() string { + // We cannot depend on go-cmp to include a diff here (not suitable for production code). + // The diff might be too large, too. But we can make some educated guesses.... + disabled := err.DisabledFeatures() + if len(disabled) == 0 { + // If we get here, DisabledFeatures needs to be updated. 
+ disabled = []string{"unknown"} + } + return fmt.Sprintf("pool %q, slice #%d: some fields were dropped by the apiserver, probably because these features are disabled: %s", err.PoolName, err.SliceIndex, strings.Join(disabled, " ")) +} + +func (err *DroppedFieldsError) DisabledFeatures() []string { + var disabled []string + + // Both slices should have the same number of devices, but better check it. + for i := 0; i < len(err.DesiredSlice.Spec.Devices) && i < len(err.ActualSlice.Spec.Devices); i++ { + if len(err.DesiredSlice.Spec.Devices[i].Taints) > len(err.ActualSlice.Spec.Devices[i].Taints) { + disabled = append(disabled, "DRADeviceTaints") + break + } + } + + // Dropped fields for partitionable devices can be detected without looking at the devices themselves. + if ptr.Deref(err.DesiredSlice.Spec.PerDeviceNodeSelection, false) && !ptr.Deref(err.ActualSlice.Spec.PerDeviceNodeSelection, false) || + len(err.DesiredSlice.Spec.SharedCounters) > len(err.ActualSlice.Spec.SharedCounters) { + disabled = append(disabled, "DRAPartitionableDevices") + } + + // The number of binding conditions for both slices should be the same. If they differ, + // it indicates that the DRADeviceBindingConditions feature is disabled. + for i := 0; i < len(err.DesiredSlice.Spec.Devices) && i < len(err.ActualSlice.Spec.Devices); i++ { + if len(err.DesiredSlice.Spec.Devices[i].BindingConditions) != len(err.ActualSlice.Spec.Devices[i].BindingConditions) || + len(err.DesiredSlice.Spec.Devices[i].BindingFailureConditions) != len(err.ActualSlice.Spec.Devices[i].BindingFailureConditions) { + disabled = append(disabled, "DRADeviceBindingConditions") + break + } + } + + // Dropped fields for consumable capacity can be detected with allowMultipleAllocations flag without looking at individual device capacity. 
+ for i := 0; i < len(err.DesiredSlice.Spec.Devices) && i < len(err.ActualSlice.Spec.Devices); i++ { + if err.DesiredSlice.Spec.Devices[i].AllowMultipleAllocations != nil && err.ActualSlice.Spec.Devices[i].AllowMultipleAllocations == nil { + disabled = append(disabled, "DRAConsumableCapacity") + break + } + } + + return disabled +} + +var _ error = &DroppedFieldsError{} + +// Stop cancels all background activity and blocks until the controller has stopped. +func (c *Controller) Stop() { + if c == nil { + return + } + c.cancel(errors.New("ResourceSlice controller was asked to stop")) + c.wg.Wait() +} + +// Update sets the new desired state of the resource information. +// +// The controller is doing a deep copy, so the caller may update +// the instance once Update returns. Nil is valid and the same +// as an empty resources struct. +func (c *Controller) Update(resources *DriverResources) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Sync all old pools.. + if c.resources != nil { + for poolName := range c.resources.Pools { + c.queue.Add(poolName) + } + } + + if resources == nil { + c.resources = &DriverResources{} + } else { + if c.reconcileOnlyPoolName != "" { + _, ok := resources.Pools[c.reconcileOnlyPoolName] + if !ok && len(resources.Pools) > 1 { + return fmt.Errorf("reconcileOnlyPoolName is set to %q, but multiple pools found (%d total)", + c.reconcileOnlyPoolName, len(resources.Pools)) + } + } + + c.resources = resources.DeepCopy() + roundTaintTimeAdded(c.resources) + } + + // ... and the new ones (might be the same). + for poolName := range c.resources.Pools { + c.queue.Add(poolName) + } + + return nil +} + +// roundTaintTimeAdded rounds all timestamps to seconds because that is all +// that we can store. Without this we would get semantic differences between +// desired and actual stored slice. 
+func roundTaintTimeAdded(resources *DriverResources) { + for _, pool := range resources.Pools { + for _, slice := range pool.Slices { + for _, device := range slice.Devices { + for _, taint := range device.Taints { + if taint.TimeAdded != nil { + taint.TimeAdded.Time = taint.TimeAdded.Round(time.Second) + } + } + } + } + } +} + +// GetStats provides some insights into operations of the controller. +func (c *Controller) GetStats() Stats { + s := Stats{ + NumCreates: atomic.LoadInt64(&c.numCreates), + NumUpdates: atomic.LoadInt64(&c.numUpdates), + NumDeletes: atomic.LoadInt64(&c.numDeletes), + } + return s +} + +type Stats struct { + // NumCreates counts the number of ResourceSlices that got created. + NumCreates int64 + // NumUpdates counts the number of ResourceSlices that got update. + NumUpdates int64 + // NumDeletes counts the number of ResourceSlices that got deleted. + NumDeletes int64 +} + +// newController creates a new controller. +func newController(ctx context.Context, options Options) (*Controller, error) { + if options.KubeClient == nil { + return nil, errors.New("KubeClient is nil") + } + if options.DriverName == "" { + return nil, errors.New("DRA driver name is empty") + } + + ctx, cancel := context.WithCancelCause(ctx) + + c := &Controller{ + cancel: cancel, + resourceClient: draclient.New(options.KubeClient), + coreClient: options.KubeClient.CoreV1(), + driverName: options.DriverName, + owner: options.Owner.DeepCopy(), + queue: options.Queue, + mutationCacheTTL: ptr.Deref(options.MutationCacheTTL, DefaultMutationCacheTTL), + syncDelay: ptr.Deref(options.SyncDelay, DefaultSyncDelay), + errorHandler: options.ErrorHandler, + lastAddByPool: make(map[string]time.Time), + reconcileOnlyPoolName: options.ReconcileOnlyPoolName, + } + if c.queue == nil { + c.queue = workqueue.NewTypedRateLimitingQueueWithConfig( + workqueue.DefaultTypedControllerRateLimiter[string](), + workqueue.TypedRateLimitingQueueConfig[string]{Name: "node_resource_slices"}, + ) + } + 
if c.errorHandler == nil { + c.errorHandler = func(ctx context.Context, err error, msg string) { + utilruntime.HandleErrorWithContext(ctx, err, msg) + } + } + if err := c.initInformer(ctx); err != nil { + return nil, err + } + + if err := c.Update(options.Resources); err != nil { + return nil, fmt.Errorf("failed to update resources: %w", err) + } + + return c, nil +} + +// initInformer initializes the informer used to watch for changes to the resources slice. +func (c *Controller) initInformer(ctx context.Context) error { + logger := klog.FromContext(ctx) + + // We always filter by driver name, by node name only for node-local resources. + selector := fields.Set{ + resourcev1.ResourceSliceSelectorDriver: c.driverName, + resourcev1.ResourceSliceSelectorNodeName: "", + } + + if c.owner != nil && c.owner.APIVersion == "v1" && c.owner.Kind == "Node" && c.reconcileOnlyPoolName == "" { + selector[resourcev1.ResourceSliceSelectorNodeName] = c.owner.Name + } + + tweakListOptions := func(options *metav1.ListOptions) { + options.FieldSelector = selector.String() + } + indexers := cache.Indexers{ + poolNameIndex: func(obj interface{}) ([]string, error) { + slice, ok := obj.(*resourcev1.ResourceSlice) + if !ok { + return []string{}, nil + } + return []string{slice.Spec.Pool.Name}, nil + }, + } + informer := cache.NewSharedIndexInformer( + &cache.ListWatch{ + ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { + tweakListOptions(&options) + slices, err := c.resourceClient.ResourceSlices().List(ctx, options) + if err == nil { + logger.V(5).Info("Listed ResourceSlices", "resourceAPI", c.resourceClient.CurrentAPI(), "numSlices", len(slices.Items), "listMeta", slices.ListMeta) + } else { + logger.V(5).Info("Listed ResourceSlices", "resourceAPI", c.resourceClient.CurrentAPI(), "err", err) + } + + if c.reconcileOnlyPoolName != "" { + for i := range slices.Items { + if slices.Items[i].Spec.Pool.Name == c.reconcileOnlyPoolName { + 
slices.Items = []resourcev1.ResourceSlice{slices.Items[i]} + return slices, nil + } + } + slices.Items = nil + return slices, nil + } + + return slices, err + }, + WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) { + tweakListOptions(&options) + w, err := c.resourceClient.ResourceSlices().Watch(ctx, options) + logger.V(5).Info("Started watching ResourceSlices", "resourceAPI", c.resourceClient.CurrentAPI(), "err", err) + + if c.reconcileOnlyPoolName != "" { + return newWrapWatcher(ctx, w, func(event watch.Event) bool { + resourceSlice, ok := event.Object.(*resourcev1.ResourceSlice) + return ok && resourceSlice.Spec.Pool.Name == c.reconcileOnlyPoolName + }), nil + } + + return w, err + }, + }, + &resourcev1.ResourceSlice{}, + // No resync because all it would do is periodically trigger syncing pools + // again by reporting all slices as updated with the object as old/new. + // Our sync method is deterministic (or should be!), so repeating it + // won't change the outcome. 
+ 0, + indexers, + ) + c.sliceStore = cache.NewIntegerResourceVersionMutationCache(logger, informer.GetStore(), informer.GetIndexer(), c.mutationCacheTTL, true /* includeAdds */) + handler, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + slice, ok := obj.(*resourcev1.ResourceSlice) + if !ok { + return + } + logger.V(5).Info("ResourceSlice add", "slice", klog.KObj(slice)) + c.queue.AddAfter(slice.Spec.Pool.Name, c.syncDelay) + logger.V(5).Info("Scheduled sync", "poolName", slice.Spec.Pool.Name, "at", time.Now().Add(c.syncDelay)) + }, + UpdateFunc: func(old, new any) { + oldSlice, ok := old.(*resourcev1.ResourceSlice) + if !ok { + return + } + newSlice, ok := new.(*resourcev1.ResourceSlice) + if !ok { + return + } + if loggerV := logger.V(6); loggerV.Enabled() { + loggerV.Info("ResourceSlice update", "slice", klog.KObj(newSlice), "diff", diff.Diff(oldSlice, newSlice)) + } else { + logger.V(5).Info("ResourceSlice update", "slice", klog.KObj(newSlice)) + } + c.queue.AddAfter(oldSlice.Spec.Pool.Name, c.syncDelay) + logger.V(5).Info("Scheduled sync", "pool", oldSlice.Spec.Pool.Name, "at", time.Now().Add(c.syncDelay)) + if oldSlice.Spec.Pool.Name != newSlice.Spec.Pool.Name { + c.queue.AddAfter(newSlice.Spec.Pool.Name, c.syncDelay) + logger.V(5).Info("Scheduled sync", "poolName", newSlice.Spec.Pool.Name, "at", time.Now().Add(c.syncDelay)) + } + }, + DeleteFunc: func(obj any) { + if tombstone, ok := obj.(cache.DeletedFinalStateUnknown); ok { + obj = tombstone.Obj + } + slice, ok := obj.(*resourcev1.ResourceSlice) + if !ok { + return + } + logger.V(5).Info("ResourceSlice delete", "slice", klog.KObj(slice)) + c.queue.AddAfter(slice.Spec.Pool.Name, c.syncDelay) + logger.V(5).Info("Scheduled sync", "poolName", slice.Spec.Pool.Name, "at", time.Now().Add(c.syncDelay)) + }, + }) + if err != nil { + return fmt.Errorf("registering event handler on the ResourceSlice informer: %w", err) + } + // Start informer and wait for our cache to be 
populated. + logger.V(3).Info("Starting ResourceSlice informer and waiting for it to sync") + c.wg.Add(1) + go func() { + defer c.wg.Done() + defer logger.V(3).Info("ResourceSlice informer has stopped") + defer c.queue.ShutDown() // Once we get here, we must have been asked to stop. + informer.Run(ctx.Done()) + }() + for !handler.HasSynced() { + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return fmt.Errorf("sync ResourceSlice informer: %w", context.Cause(ctx)) + } + } + logger.V(3).Info("ResourceSlice informer has synced") + return nil +} + +// run is running in the background. +func (c *Controller) run(ctx context.Context) { + for c.processNextWorkItem(ctx) { + } +} + +func (c *Controller) processNextWorkItem(ctx context.Context) bool { + poolName, shutdown := c.queue.Get() + if shutdown { + return false + } + defer c.queue.Done(poolName) + logger := klog.FromContext(ctx) + + err := c.syncPool(klog.NewContext(ctx, klog.LoggerWithValues(logger, "poolName", poolName)), poolName) + if err != nil { + c.errorHandler(ctx, err, "processing ResourceSlice objects") + c.queue.AddRateLimited(poolName) + + // Return without removing the work item from the queue. + // It will be retried. + return true + } + + c.queue.Forget(poolName) + return true +} + +// syncPool processes one pool. Only runs inside a single worker, so there +// is no need for locking except when accessing c.resources, which may +// be updated at any time by the user of the controller. +func (c *Controller) syncPool(ctx context.Context, poolName string) error { + logger := klog.FromContext(ctx) + start := time.Now() + + // Gather information about the actual and desired state. 
+ var slices []*resourcev1.ResourceSlice + objs, err := c.sliceStore.ByIndex(poolNameIndex, poolName) + if err != nil { + return fmt.Errorf("retrieve ResourceSlice objects: %w", err) + } + for _, obj := range objs { + if slice, ok := obj.(*resourcev1.ResourceSlice); ok { + slices = append(slices, slice) + } + } + var resources *DriverResources + c.mutex.RLock() + resources = c.resources + c.mutex.RUnlock() + + pool, ok := resources.Pools[poolName] + if !ok { + if len(slices) > 0 { + // All are obsolete, pool does not exist anymore. + logger.V(5).Info("Removing resource slices after pool removal") + if err := c.removeSlices(ctx, slices); err != nil { + return fmt.Errorf("remove slices: %w", err) + } + } + // Pool does not exist anymore, nothing more to do. + return nil + } + + // Retrieve node object to get UID? + // The result gets cached and is expected to not change while + // the controller runs. + var nodeName string + if c.owner != nil && c.owner.APIVersion == "v1" && c.owner.Kind == "Node" { + if c.reconcileOnlyPoolName == "" { + nodeName = c.owner.Name + } + if c.owner.UID == "" { + node, err := c.coreClient.Nodes().Get(ctx, c.owner.Name, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("retrieve node %q: %w", c.owner.Name, err) + } + // There is only one worker, so no locking needed. + c.owner.UID = node.UID + } + } + // Slices that don't match any driver slice need to be deleted. + obsoleteSlices := make([]*resourcev1.ResourceSlice, 0, len(slices)) + + // Determine highest generation. + var generation int64 + for _, slice := range slices { + if slice.Spec.Pool.Generation > generation { + generation = slice.Spec.Pool.Generation + } + } + + // Everything older is obsolete. 
+ currentSlices := make([]*resourcev1.ResourceSlice, 0, len(slices)) + for _, slice := range slices { + if slice.Spec.Pool.Generation < generation { + obsoleteSlices = append(obsoleteSlices, slice) + } else { + currentSlices = append(currentSlices, slice) + } + } + logger.V(5).Info("Existing slices", "obsolete", klog.KObjSlice(obsoleteSlices), "current", klog.KObjSlice(currentSlices)) + + // Match each existing slice against the desired slices. + // Two slices "match" if they contain exactly the + // same device IDs, in an arbitrary order. As a + // special case, slices are also considered + // "matched" in the scenario where there's a single + // existing slice and a single desired slice. Such a + // matched slice gets updated with the desired + // content if there is a difference. + // + // In the case where there is more than one existing + // or desired slices, adding or removing devices is + // done by deleting the old slice and creating a new one. + // + // This is primarily a simplification of the code: + // to support adding or removing devices from + // existing slices, we would have to identify "most + // similar" slices (= minimal editing distance). + // + // In currentSliceForDesiredSlice we keep track of + // which desired slice has a matched slice. + // + // At the end of the loop, each current slice is either + // a match or obsolete. + currentSliceForDesiredSlice := make(map[int]*resourcev1.ResourceSlice, len(pool.Slices)) + if len(currentSlices) == 1 && len(pool.Slices) == 1 { + // If there's just one existing slice and one desired slice, assume + // they "matched" such that if required, it is the existing slice + // which gets updated and we avoid an unnecessary deletion and + // recreation of the slice. + currentSliceForDesiredSlice[0] = currentSlices[0] + } else { + for _, currentSlice := range currentSlices { + matched := false + for i := range pool.Slices { + if _, ok := currentSliceForDesiredSlice[i]; ok { + // Already has a match. 
+ continue + } + if sameSlice(currentSlice, &pool.Slices[i]) { + currentSliceForDesiredSlice[i] = currentSlice + logger.V(5).Info("Matched existing slice", "slice", klog.KObj(currentSlice), "matchIndex", i) + matched = true + break + } + } + if !matched { + obsoleteSlices = append(obsoleteSlices, currentSlice) + logger.V(5).Info("Unmatched existing slice", "slice", klog.KObj(currentSlice)) + } + } + } + + // Desired metadata which must be set in each slice. + resourceSliceCount := len(pool.Slices) + numMatchedSlices := len(currentSliceForDesiredSlice) + numNewSlices := resourceSliceCount - numMatchedSlices + desiredPool := resourcev1.ResourcePool{ + Name: poolName, + Generation: generation, // May get updated later. + ResourceSliceCount: int64(resourceSliceCount), + } + desiredAllNodes := pool.NodeSelector == nil && nodeName == "" + + // Now for each desired slice, figure out which of them are changed. + changedDesiredSlices := sets.New[int]() + for i, currentSlice := range currentSliceForDesiredSlice { + // Reordering entries is a difference and causes an update even if the + // entries are the same. 
+ if !apiequality.Semantic.DeepEqual(¤tSlice.Spec.Pool, &desiredPool) || + !apiequality.Semantic.DeepEqual(currentSlice.Spec.NodeSelector, pool.NodeSelector) || + ptr.Deref(currentSlice.Spec.AllNodes, false) != desiredAllNodes || + !DevicesDeepEqual(currentSlice.Spec.Devices, pool.Slices[i].Devices) || + !apiequality.Semantic.DeepEqual(currentSlice.Spec.SharedCounters, pool.Slices[i].SharedCounters) || + !apiequality.Semantic.DeepEqual(currentSlice.Spec.PerDeviceNodeSelection, pool.Slices[i].PerDeviceNodeSelection) { + changedDesiredSlices.Insert(i) + logger.V(5).Info("Need to update slice", "slice", klog.KObj(currentSlice), "matchIndex", i) + } + } + logger.V(5).Info("Completed comparison", + "numObsolete", len(obsoleteSlices), + "numMatchedSlices", len(currentSliceForDesiredSlice), + "numChangedMatchedSlices", len(changedDesiredSlices), + "numNewSlices", numNewSlices, + ) + + bumpedGeneration := false + switch { + case pool.Generation > generation: + // Bump up the generation if the driver asked for it, or + // start with a non-zero generation. + generation = pool.Generation + bumpedGeneration = true + logger.V(5).Info("Bumped generation to driver-provided generation", "generation", generation) + case numNewSlices == 0 && len(changedDesiredSlices) <= 1: + logger.V(5).Info("Kept generation because at most one update API call is necessary", "generation", generation) + default: + generation++ + bumpedGeneration = true + logger.V(5).Info("Bumped generation by one", "generation", generation) + } + desiredPool.Generation = generation + + // First delete obsolete slices. If the desired slices are faulty, then it's still better to + // remove devices that the driver no longer has, even if we cannot publish the new ones. + if err := c.removeSlices(ctx, obsoleteSlices); err != nil { + return fmt.Errorf("remove slices: %w", err) + } + + // Update existing slices. 
+ for i, currentSlice := range currentSliceForDesiredSlice { + if !changedDesiredSlices.Has(i) && !bumpedGeneration { + continue + } + slice := currentSlice.DeepCopy() + slice.Spec.Pool = desiredPool + // No need to set the node name. If it was different, we wouldn't + // have listed the existing slice. + // + // When adding new fields here, then also extend sliceStored. + slice.Spec.NodeSelector = pool.NodeSelector + slice.Spec.AllNodes = refIfNotZero(desiredAllNodes) + slice.Spec.SharedCounters = pool.Slices[i].SharedCounters + slice.Spec.PerDeviceNodeSelection = pool.Slices[i].PerDeviceNodeSelection + // Preserve TimeAdded from existing device, if there is a matching device and taint. + slice.Spec.Devices = copyTaintTimeAdded(slice.Spec.Devices, pool.Slices[i].Devices) + + actualSlice, err := c.resourceClient.ResourceSlices().Update(ctx, slice, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("update resource slice: %w", err) + } + logger.V(5).Info("Updated existing resource slice", "slice", klog.KObj(slice)) + atomic.AddInt64(&c.numUpdates, 1) + c.sliceStored(ctx, "update ResourceSlice", poolName, pool, i, slice, actualSlice) + } + + // Create new slices. + added := false + for i := 0; i < len(pool.Slices); i++ { + if _, ok := currentSliceForDesiredSlice[i]; ok { + // Was handled above through an update. 
+ continue + } + var ownerReferences []metav1.OwnerReference + if c.owner != nil { + ownerReferences = append(ownerReferences, + metav1.OwnerReference{ + APIVersion: c.owner.APIVersion, + Kind: c.owner.Kind, + Name: c.owner.Name, + UID: c.owner.UID, + Controller: ptr.To(true), + }, + ) + } + generateName := c.driverName + "-" + if c.owner != nil { + generateName = c.owner.Name + "-" + generateName + } + slice := &resourcev1.ResourceSlice{ + ObjectMeta: metav1.ObjectMeta{ + OwnerReferences: ownerReferences, + GenerateName: generateName, + }, + Spec: resourcev1.ResourceSliceSpec{ + Driver: c.driverName, + Pool: desiredPool, + NodeName: refIfNotZero(nodeName), + NodeSelector: pool.NodeSelector, + AllNodes: refIfNotZero(desiredAllNodes), + Devices: pool.Slices[i].Devices, + SharedCounters: pool.Slices[i].SharedCounters, + PerDeviceNodeSelection: pool.Slices[i].PerDeviceNodeSelection, + }, + } + + // It can happen that we create a missing slice, some + // other change than the create causes another sync of + // the pool, and then a second slice for the same set + // of devices would get created because the controller has + // no copy of the first slice instance in its informer + // cache yet. + // + // Using a https://pkg.go.dev/k8s.io/client-go/tools/cache#MutationCache + // avoids that. 
+ actualSlice, err := c.resourceClient.ResourceSlices().Create(ctx, slice, metav1.CreateOptions{}) + if err != nil { + return fmt.Errorf("create resource slice: %w", err) + } + logger.V(5).Info("Created new resource slice", "slice", klog.KObj(actualSlice)) + atomic.AddInt64(&c.numCreates, 1) + added = true + c.sliceStored(ctx, "create ResourceSlice", poolName, pool, i, slice, actualSlice) + } + + now := time.Now() + if added { + c.lastAddByPool[poolName] = now + logger.V(5).Info("Added slices") + } else if lastAdd, ok := c.lastAddByPool[poolName]; ok && start.After(lastAdd.Add(c.mutationCacheTTL)) { + // This sync started after the last add expired from the cache, + // so we are done and don't need to check again. + delete(c.lastAddByPool, poolName) + logger.V(5).Info("Done with re-syncing") + } + if lastAdd, ok := c.lastAddByPool[poolName]; ok { + // Need to check again. + // + // Scheduling the resync races with scheduling them in informer events, but that's okay: + // what matters is that we sync at all at some point. + // + // lastAdd was taked by time.Now() above and thus is slightly higher or equal + // to the time taken by the mutation cache when the slice was added, so we + // can be sure that any sync running at this time will not see the added + // slice because it will be expired. + when := lastAdd.Add(c.mutationCacheTTL) + c.queue.AddAfter(poolName, when.Sub(now)) + logger.V(5).Info("Scheduled re-sync", "at", when) + } + + return nil +} + +func (c *Controller) removeSlices(ctx context.Context, slices []*resourcev1.ResourceSlice) error { + logger := klog.FromContext(ctx) + + for _, slice := range slices { + options := metav1.DeleteOptions{ + Preconditions: &metav1.Preconditions{ + UID: &slice.UID, + ResourceVersion: &slice.ResourceVersion, + }, + } + // It can happen that we sync again shortly after deleting a + // slice and before the slice gets removed from the informer + // cache. 
The MutationCache can't help here because it does not + // track pending deletes. + // + // If this happens, we get a "not found error" and nothing + // changes on the server. The only downside is the extra API + // call. This isn't as bad as extra creates. + err := c.resourceClient.ResourceSlices().Delete(ctx, slice.Name, options) + switch { + case err == nil: + logger.V(5).Info("Deleted obsolete resource slice", "slice", klog.KObj(slice), "deleteOptions", options) + atomic.AddInt64(&c.numDeletes, 1) + case apierrors.IsNotFound(err): + logger.V(5).Info("Resource slice was already deleted earlier", "slice", klog.KObj(slice)) + default: + return fmt.Errorf("delete resource slice: %w", err) + } + } + + return nil +} + +// sliceStored gets called after creating or updating a slice. +// The slice might have been modified during the roundtrip +// through the apiserver. +func (c *Controller) sliceStored(ctx context.Context, msg, poolName string, pool Pool, sliceIndex int, desiredSlice, actualSlice *resourcev1.ResourceSlice) { + c.sliceStore.Mutation(actualSlice) + + // One difference is normal: the apiserver may have added TimeAdded to taints. + // This mutates desiredSlice for the DeepEqual below. + if copyServerDefaults(desiredSlice, actualSlice) { + pool.Slices[sliceIndex].Devices = actualSlice.Spec.Devices + } + + // Some fields may have been dropped. When we receive + // the updated slice through the informer, the + // DeepEqual fails and the controller would try to + // update again, etc. To break that cycle, update our + // desired state of the world so that it matches what + // we can store. 
+ if !apiequality.Semantic.DeepEqual(desiredSlice.Spec.PerDeviceNodeSelection, actualSlice.Spec.PerDeviceNodeSelection) || + !apiequality.Semantic.DeepEqual(desiredSlice.Spec.SharedCounters, actualSlice.Spec.SharedCounters) || + !apiequality.Semantic.DeepEqual(desiredSlice.Spec.Devices, actualSlice.Spec.Devices) { + pool.Slices[sliceIndex].PerDeviceNodeSelection = actualSlice.Spec.PerDeviceNodeSelection + pool.Slices[sliceIndex].SharedCounters = actualSlice.Spec.SharedCounters + pool.Slices[sliceIndex].Devices = actualSlice.Spec.Devices + + err := &DroppedFieldsError{ + PoolName: poolName, + SliceIndex: sliceIndex, + DesiredSlice: desiredSlice.DeepCopy(), + ActualSlice: actualSlice.DeepCopy(), + } + c.errorHandler(ctx, err, msg) + } +} + +func copyServerDefaults(desiredSlice, actualSlice *resourcev1.ResourceSlice) bool { + copied := false + + // Should have the same length and entries in the same order. + for i := 0; i < len(desiredSlice.Spec.Devices) && i < len(actualSlice.Spec.Devices); i++ { + for e := 0; e < len(desiredSlice.Spec.Devices[i].Taints) && e < len(actualSlice.Spec.Devices[i].Taints); e++ { + if desiredSlice.Spec.Devices[i].Taints[e].TimeAdded == nil && actualSlice.Spec.Devices[i].Taints[e].TimeAdded != nil { + if !copied { + desiredSlice.Spec = *desiredSlice.Spec.DeepCopy() + copied = true + } + desiredSlice.Spec.Devices[i].Taints[e].TimeAdded = actualSlice.Spec.Devices[i].Taints[e].TimeAdded + } + } + } + return copied +} + +func sameSlice(existingSlice *resourcev1.ResourceSlice, desiredSlice *Slice) bool { + if len(existingSlice.Spec.Devices) != len(desiredSlice.Devices) { + return false + } + + existingDevices := sets.New[string]() + for _, device := range existingSlice.Spec.Devices { + existingDevices.Insert(device.Name) + } + for _, device := range desiredSlice.Devices { + if !existingDevices.Has(device.Name) { + return false + } + } + + // Same number of devices, names all present -> equal. 
+ return true +} + +// copyTaintTimeAdded copies existing TimeAdded values from one slice into +// the other if the other one doesn't have it for a taint. Both input +// slices are read-only. +func copyTaintTimeAdded(from, to []resourcev1.Device) []resourcev1.Device { + to = slices.Clone(to) + for i, toDevice := range to { + index := slices.IndexFunc(from, func(fromDevice resourcev1.Device) bool { + return fromDevice.Name == toDevice.Name + }) + if index < 0 { + // No matching device. + continue + } + fromDevice := from[index] + for j, toTaint := range toDevice.Taints { + if toTaint.TimeAdded != nil { + // Already set. + continue + } + // Preserve the old TimeAdded if all other fields are the same. + index := slices.IndexFunc(fromDevice.Taints, func(fromTaint resourcev1.DeviceTaint) bool { + return toTaint.Key == fromTaint.Key && + toTaint.Value == fromTaint.Value && + toTaint.Effect == fromTaint.Effect + }) + if index < 0 { + // No matching old taint. + continue + } + // In practice, devices are unlikely to have many + // taints. Just clone the entire device before we + // motify it, it's unlikely that we do this more than once. + to[i] = *toDevice.DeepCopy() + to[i].Taints[j].TimeAdded = fromDevice.Taints[index].TimeAdded + } + } + return to +} + +// DevicesDeepEqual compares two slices of Devices. It behaves like +// apiequality.Semantic.DeepEqual, with one small difference: +// a nil DeviceTaint.TimeAdded is equal to a non-nil time. +// Also, rounding to full seconds (caused by round-tripping) is +// tolerated. 
+func DevicesDeepEqual(a, b []resourcev1.Device) bool {
+	return devicesSemantic.DeepEqual(a, b)
+}
+
+var devicesSemantic = func() conversion.Equalities {
+	semantic := apiequality.Semantic.Copy()
+	if err := semantic.AddFunc(deviceTaintEqual); err != nil {
+		panic(err)
+	}
+	return semantic
+}()
+
+func deviceTaintEqual(a, b resourcev1.DeviceTaint) bool {
+	if a.TimeAdded != nil && b.TimeAdded != nil {
+		delta := b.TimeAdded.Sub(a.TimeAdded.Time)
+		if delta < -time.Second || delta > time.Second {
+			return false
+		}
+	}
+	return a.Key == b.Key &&
+		a.Value == b.Value &&
+		a.Effect == b.Effect
+}
+
+func refIfNotZero[T comparable](t T) *T {
+	var zero T
+	if t == zero {
+		return nil
+	}
+	return &t
+}
diff --git a/images/virtualization-dra/internal/plugin/wrapresourceslice/watcher.go b/images/virtualization-dra/internal/plugin/wrapresourceslice/watcher.go
new file mode 100644
index 0000000000..5c1fadf7b4
--- /dev/null
+++ b/images/virtualization-dra/internal/plugin/wrapresourceslice/watcher.go
@@ -0,0 +1,74 @@
+/*
+Copyright 2026 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package wrapresourceslice
+
+import (
+	"context"
+
+	"k8s.io/apimachinery/pkg/watch"
+)
+
+// newWrapWatcher wraps w so that only events accepted by match are forwarded
+// to the returned watcher's ResultChan. A nil match forwards every event.
+// The forwarding goroutine ends when ctx is cancelled, when Stop is called,
+// or when w closes its result channel.
+func newWrapWatcher(ctx context.Context, w watch.Interface, match func(event watch.Event) bool) watch.Interface {
+	ctx, cancel := context.WithCancel(ctx)
+
+	watcher := &wrapWatcher{
+		watcher: w,
+		match:   match,
+		ctx:     ctx,
+		cancel:  cancel,
+		result:  make(chan watch.Event),
+	}
+	go watcher.receive(ctx)
+
+	return watcher
+}
+
+// wrapWatcher is a filtering watch.Interface on top of another watcher.
+type wrapWatcher struct {
+	watcher watch.Interface
+	match   func(event watch.Event) bool
+
+	ctx    context.Context
+	cancel context.CancelFunc
+	result chan watch.Event
+}
+
+// receive forwards matching events from the wrapped watcher to w.result.
+// It is the only sender on w.result and the only caller of the wrapped
+// watcher's Stop, so both the close and the stop happen exactly once.
+func (w *wrapWatcher) receive(ctx context.Context) {
+	// Deferred calls run in reverse order: cancel the context, stop the
+	// wrapped watcher, then close the result channel for consumers.
+	defer close(w.result)
+	defer w.watcher.Stop()
+	defer w.cancel()
+
+	resultChan := w.watcher.ResultChan()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case event, ok := <-resultChan:
+			if !ok {
+				// The wrapped watcher ended. Without this check the loop
+				// would spin on zero-value events from the closed channel.
+				return
+			}
+			if w.match != nil && !w.match(event) {
+				continue
+			}
+			// Do not block forever when the consumer stopped reading.
+			select {
+			case w.result <- event:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}
+}
+
+// ResultChan returns the channel with the filtered events. It is closed
+// once the watcher has ended.
+func (w *wrapWatcher) ResultChan() <-chan watch.Event {
+	return w.result
+}
+
+// Stop ends the watch. It may be called more than once. The wrapped watcher
+// is stopped by the receive goroutine after the cancellation is observed.
+func (w *wrapWatcher) Stop() {
+	w.cancel()
+}
diff --git a/images/virtualization-dra/internal/plugin/wrapresourceslice/zz_generated.deepcopy.go b/images/virtualization-dra/internal/plugin/wrapresourceslice/zz_generated.deepcopy.go
new file mode 100644
index 0000000000..0a48df3329
--- /dev/null
+++ b/images/virtualization-dra/internal/plugin/wrapresourceslice/zz_generated.deepcopy.go
@@ -0,0 +1,132 @@
+//go:build !ignore_autogenerated
+// +build !ignore_autogenerated
+
+/*
+Copyright The Kubernetes Authors.
+Copyright 2024 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. + +Initially copied from https://github.com/kubernetes/dynamic-resource-allocation/blob/v0.34.2/resourceslice/zz_generated.deepcopy.go +*/ + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package wrapresourceslice + +import ( + v1 "k8s.io/api/core/v1" + resourcev1 "k8s.io/api/resource/v1" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DriverResources) DeepCopyInto(out *DriverResources) { + *out = *in + if in.Pools != nil { + in, out := &in.Pools, &out.Pools + *out = make(map[string]Pool, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DriverResources. +func (in *DriverResources) DeepCopy() *DriverResources { + if in == nil { + return nil + } + out := new(DriverResources) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Owner) DeepCopyInto(out *Owner) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Owner. +func (in *Owner) DeepCopy() *Owner { + if in == nil { + return nil + } + out := new(Owner) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *Pool) DeepCopyInto(out *Pool) { + *out = *in + if in.NodeSelector != nil { + in, out := &in.NodeSelector, &out.NodeSelector + *out = new(v1.NodeSelector) + (*in).DeepCopyInto(*out) + } + if in.Slices != nil { + in, out := &in.Slices, &out.Slices + *out = make([]Slice, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pool. +func (in *Pool) DeepCopy() *Pool { + if in == nil { + return nil + } + out := new(Pool) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Slice) DeepCopyInto(out *Slice) { + *out = *in + if in.Devices != nil { + in, out := &in.Devices, &out.Devices + *out = make([]resourcev1.Device, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.SharedCounters != nil { + in, out := &in.SharedCounters, &out.SharedCounters + *out = make([]resourcev1.CounterSet, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.PerDeviceNodeSelection != nil { + in, out := &in.PerDeviceNodeSelection, &out.PerDeviceNodeSelection + *out = new(bool) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Slice. +func (in *Slice) DeepCopy() *Slice { + if in == nil { + return nil + } + out := new(Slice) + in.DeepCopyInto(out) + return out +} diff --git a/images/virtualization-dra/internal/usb-gateway/attach_record.go b/images/virtualization-dra/internal/usb-gateway/attach_record.go new file mode 100644 index 0000000000..66c47fb539 --- /dev/null +++ b/images/virtualization-dra/internal/usb-gateway/attach_record.go @@ -0,0 +1,179 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package usbgateway + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "slices" + "sync" + + "github.com/deckhouse/virtualization-dra/pkg/usbip" +) + +const DefaultRecordStateDir = "/var/run/virtualization-dra/usb" + +type attachRecord struct { + Entries []AttachEntry `json:"entries,omitempty"` +} + +type AttachEntry struct { + Rhport int `json:"rhport"` + BusID string `json:"busID"` + DeviceName string `json:"deviceName"` +} + +func (e AttachEntry) Validate() error { + if e.Rhport < 0 { + return fmt.Errorf("rhport is required") + } + if e.BusID == "" { + return fmt.Errorf("busID is required") + } + if e.DeviceName == "" { + return fmt.Errorf("deviceName is required") + } + return nil +} + +type attachRecordManager struct { + recordFile string + getter usbip.AttachInfoGetter + + mu sync.RWMutex + record attachRecord +} + +func newAttachRecordManager(stateDir string, getter usbip.AttachInfoGetter) (*attachRecordManager, error) { + err := os.MkdirAll(stateDir, 0o700) + if err != nil { + return nil, err + } + + recordFile := filepath.Join(stateDir, "attach-record.json") + if _, err = os.Stat(recordFile); err != nil { + if !os.IsNotExist(err) { + return nil, err + } + + b, err := json.Marshal(attachRecord{}) + if err != nil { + return nil, err + } + err = os.WriteFile(recordFile, b, 0o600) + if err != nil { + return nil, err + } + } + + r := attachRecordManager{ + recordFile: recordFile, + getter: getter, + } + + if err = r.Refresh(); err != nil { + return nil, fmt.Errorf("failed to Refresh record: %w", err) + } + + return &r, nil +} + +func (r 
*attachRecordManager) Refresh() error { + r.mu.Lock() + defer r.mu.Unlock() + + infos, err := r.getter.GetAttachInfo() + if err != nil { + return err + } + + ports := make(map[int]struct{}, len(infos)) + for _, info := range infos { + ports[info.Port] = struct{}{} + } + + b, err := os.ReadFile(r.recordFile) + if err != nil { + return err + } + + record := attachRecord{} + if err = json.Unmarshal(b, &record); err != nil { + return err + } + + // keep only real entries + var newEntries []AttachEntry + for _, e := range record.Entries { + if _, ok := ports[e.Rhport]; ok { + newEntries = append(newEntries, e) + } + } + + record.Entries = newEntries + + r.record = record + + return nil +} + +func (r *attachRecordManager) GetEntries() []AttachEntry { + r.mu.RLock() + defer r.mu.RUnlock() + + return slices.Clone(r.record.Entries) +} + +func (r *attachRecordManager) AddEntry(e AttachEntry) error { + if err := e.Validate(); err != nil { + return err + } + + r.mu.Lock() + defer r.mu.Unlock() + + for _, entry := range r.record.Entries { + if entry.Rhport == e.Rhport { + return fmt.Errorf("entry with Rhport %d already exists", e.Rhport) + } + if entry.BusID == e.BusID { + return fmt.Errorf("entry with BusID %s already exists", e.BusID) + } + if entry.DeviceName == e.DeviceName { + return fmt.Errorf("entry with DeviceName %s already exists", e.DeviceName) + } + } + + newEntries := slices.Clone(r.record.Entries) + newEntries = append(newEntries, e) + + record := attachRecord{Entries: newEntries} + + b, err := json.Marshal(record) + if err != nil { + return err + } + + if err = os.WriteFile(r.recordFile, b, 0o600); err != nil { + return err + } + + r.record = record + return nil +} diff --git a/images/virtualization-dra/internal/usb-gateway/controller.go b/images/virtualization-dra/internal/usb-gateway/controller.go new file mode 100644 index 0000000000..d3ac9d8efd --- /dev/null +++ b/images/virtualization-dra/internal/usb-gateway/controller.go @@ -0,0 +1,308 @@ +/* +Copyright 
2026 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package usbgateway
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"net"
+	"strconv"
+	"sync"
+	"time"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/equality"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/kubernetes"
+	"k8s.io/client-go/tools/cache"
+	"k8s.io/client-go/util/workqueue"
+
+	"github.com/deckhouse/virtualization-dra/pkg/controller"
+	"github.com/deckhouse/virtualization-dra/pkg/patch"
+	"github.com/deckhouse/virtualization-dra/pkg/usbip"
+)
+
+const controllerName = "usb-gateway-controller"
+
+// USBGatewayController reconciles one secret (secretName in namespace) whose
+// data maps node names to usbipd addresses, and mirrors that data into
+// nodeAddresses. mu guards nodeAddresses.
+type USBGatewayController struct {
+	secretName           string
+	namespace            string
+	nodeName             string
+	usbipdAddr           string
+	client               kubernetes.Interface
+	secretIndexer        cache.Indexer
+	resourceSliceIndexer cache.Indexer
+	usbIP                usbip.Interface
+	queue                workqueue.TypedRateLimitingInterface[string]
+	hasSynced            cache.InformerSynced
+	attachRecordManager  *attachRecordManager
+
+	mu            sync.RWMutex
+	nodeAddresses map[string]string
+
+	log *slog.Logger
+}
+
+// NewUSBGatewayController builds the controller, registers its secret event
+// handlers and starts the periodic secret checker. It blocks until both
+// informer caches have synced or ctx is done, because runSecretChecker waits
+// for the caches.
+func NewUSBGatewayController(
+	ctx context.Context,
+	secretName, namespace, nodeName, usbipdHost string,
+	usbipdPort int,
+	client kubernetes.Interface,
+	secretInformer, resourceSliceInformer cache.SharedIndexInformer,
+	usbIP usbip.Interface,
+) (*USBGatewayController, error) {
+	queue := workqueue.NewTypedRateLimitingQueueWithConfig(
+		workqueue.DefaultTypedControllerRateLimiter[string](),
+		workqueue.TypedRateLimitingQueueConfig[string]{Name: controllerName},
+	)
+
+	attachRecordManager, err := newAttachRecordManager(DefaultRecordStateDir, usbIP)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &USBGatewayController{
+		secretName:           secretName,
+		namespace:            namespace,
+		nodeName:             nodeName,
+		usbipdAddr:           net.JoinHostPort(usbipdHost, strconv.Itoa(usbipdPort)),
+		client:               client,
+		secretIndexer:        secretInformer.GetIndexer(),
+		resourceSliceIndexer: resourceSliceInformer.GetIndexer(),
+		usbIP:                usbIP,
+		queue:                queue,
+		log:                  slog.With(slog.String("controller", controllerName)),
+		attachRecordManager:  attachRecordManager,
+	}
+
+	_, err = secretInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
+		AddFunc:    c.addSecret,
+		UpdateFunc: c.updateSecret,
+		DeleteFunc: c.deleteSecret,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("unable to add event handler to secret informer: %w", err)
+	}
+
+	c.hasSynced = func() bool {
+		return secretInformer.HasSynced() && resourceSliceInformer.HasSynced()
+	}
+
+	err = c.runSecretChecker(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to run secret checker: %w", err)
+	}
+
+	return c, nil
+}
+
+// runSecretChecker waits for the caches, enqueues the secret key once and
+// then starts a goroutine that re-enqueues the key every 30 seconds while
+// the secret is missing from the cache. The goroutine ends with ctx.
+func (c *USBGatewayController) runSecretChecker(ctx context.Context) error {
+	if !cache.WaitForCacheSync(ctx.Done(), c.hasSynced) {
+		return fmt.Errorf("failed to wait for caches to sync")
+	}
+
+	key := controller.KeyFunc(c.namespace, c.secretName)
+	c.queueAdd(key)
+
+	go func() {
+		// The ticker belongs to this goroutine. Stopping it with a defer in
+		// the enclosing function would silence ticker.C as soon as
+		// runSecretChecker returns, and the check below would never run.
+		ticker := time.NewTicker(time.Second * 30)
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				exist, err := c.secretExists(key)
+				if err != nil {
+					c.log.Error("Failed to check secret existence", slog.Any("error", err))
+				}
+				if !exist {
+					c.queueAdd(key)
+				}
+			}
+		}
+	}()
+
+	return nil
+}
+
+// addSecret enqueues the controller's own secret when it is added.
+func (c *USBGatewayController) addSecret(obj interface{}) {
+	if secret, ok := obj.(*corev1.Secret); ok && c.isMySecret(secret) {
+		c.enqueueSecret(secret)
+	} else if !ok {
+		c.log.Error("expected secret, got", slog.Any("obj", obj))
+	}
+}
+
+// deleteSecret enqueues the controller's own secret when it is deleted.
+func (c *USBGatewayController) deleteSecret(obj interface{}) {
+	// An informer that missed the delete event delivers a tombstone that
+	// wraps the last known object.
+	if tombstone, ok := obj.(cache.DeletedFinalStateUnknown); ok {
+		obj = tombstone.Obj
+	}
+	if secret, ok := obj.(*corev1.Secret); ok && c.isMySecret(secret) {
+		c.enqueueSecret(secret)
+	} else if !ok {
+		c.log.Error("expected secret, got", slog.Any("obj", obj))
+	}
+}
+
+// updateSecret enqueues the controller's own secret when its Data changed.
+func (c *USBGatewayController) updateSecret(oldObj, newObj interface{}) {
+	oldSecret, oldOk := oldObj.(*corev1.Secret)
+	newSecret, newOk := newObj.(*corev1.Secret)
+
+	if !oldOk || !newOk {
+		c.log.Error("expected secret, got", slog.Any("old", oldObj), slog.Any("new", newObj))
+		return
+	}
+
+	if c.isMySecret(newSecret) && !equality.Semantic.DeepEqual(oldSecret.Data, newSecret.Data) {
+		c.enqueueSecret(newSecret)
+	}
+}
+
+func (c *USBGatewayController) isMySecret(secret *corev1.Secret) bool {
+	return secret.Name == c.secretName && secret.Namespace == c.namespace
+}
+
+func (c *USBGatewayController) isMySecretKey(key string) bool {
+	return key == controller.KeyFunc(c.namespace, c.secretName)
+}
+
+func (c *USBGatewayController) enqueueSecret(secret *corev1.Secret) {
+	c.queueAdd(controller.MetaObjectKeyFunc(secret))
+}
+
+func (c *USBGatewayController) queueAdd(key string) {
+	c.queue.Add(key)
+}
+
+func (c *USBGatewayController) Queue() workqueue.TypedRateLimitingInterface[string] {
+	return c.queue
+}
+
+func (c *USBGatewayController) HasSynced() bool {
+	return c.hasSynced()
+}
+
+func (c *USBGatewayController) Logger() *slog.Logger {
+	return c.log
+}
+
+// Sync reconciles the secret identified by key: it creates the secret when
+// it is missing from the cache, otherwise makes sure this node's address is
+// stored in it and copies the secret data into nodeAddresses.
+func (c *USBGatewayController) Sync(ctx context.Context, key string) error {
+	log := c.log.With("key", key)
+	log.Info("syncing secret")
+
+	if !c.isMySecretKey(key) {
+		log.Error("False try reconcile other secret, please report a bug")
+		return nil
+	}
+
+	secret, err := c.getSecret(key)
+	if err != nil {
+		return err
+	}
+	if secret == nil {
+		return c.createSecret(ctx)
+	}
+
+	secret, err = c.ensureAddress(ctx, secret)
+	if err != nil {
+		return err
+	}
+
+	c.syncAddresses(secret)
+
+	return nil
+}
+
+func (c *USBGatewayController) secretExists(key string) (bool, error) {
+	_, exists, err := c.secretIndexer.GetByKey(key)
+	return exists, err
+}
+
+// getSecret returns a deep copy of the cached secret, or nil when the cache
+// has no object for key.
+func (c *USBGatewayController) getSecret(key string) (*corev1.Secret, error) {
+	obj, exists, err := c.secretIndexer.GetByKey(key)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get secret: %w", err)
+	}
+	if !exists {
+		return nil, nil
+	}
+	secret, ok := obj.(*corev1.Secret)
+	if !ok {
+		return nil, fmt.Errorf("expected secret, got %T", obj)
+	}
+	return secret.DeepCopy(), nil
+}
+
+// createSecret creates the secret with this node's usbipd address as its
+// only data entry.
+func (c *USBGatewayController) createSecret(ctx context.Context) error {
+	secret := &corev1.Secret{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "Secret",
+			APIVersion: corev1.SchemeGroupVersion.String(),
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      c.secretName,
+			Namespace: c.namespace,
+		},
+		Data: map[string][]byte{
+			c.nodeName: []byte(c.usbipdAddr),
+		},
+	}
+
+	_, err := c.client.CoreV1().Secrets(c.namespace).Create(ctx, secret, metav1.CreateOptions{})
+	if err != nil {
+		return fmt.Errorf("failed to create secret: %w", err)
+	}
+
+	return nil
+}
+
+// ensureAddress returns the secret unchanged when it already holds this
+// node's usbipd address. Otherwise it sends a JSON patch that first tests
+// the value read from the cache and then adds or replaces the entry, and
+// returns the patched secret.
+func (c *USBGatewayController) ensureAddress(ctx context.Context, secret *corev1.Secret) (*corev1.Secret, error) {
+	addr, exists := secret.Data[c.nodeName]
+	if string(addr) == c.usbipdAddr {
+		return secret, nil
+	}
+
+	jp := patch.NewJSONPatch(patch.WithTest("/data/"+c.nodeName, addr))
+	if exists {
+		jp.Append(patch.WithReplace("/data/"+c.nodeName, []byte(c.usbipdAddr)))
+	} else {
+		jp.Append(patch.WithAdd("/data/"+c.nodeName, []byte(c.usbipdAddr)))
+	}
+
+	bytes, err := jp.Bytes()
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate patch: %w", err)
+	}
+
+	newSecret, err := c.client.CoreV1().Secrets(c.namespace).Patch(ctx, secret.Name, types.JSONPatchType, bytes, metav1.PatchOptions{})
+	if err != nil {
+		return nil, fmt.Errorf("failed to patch secret: %w", err)
+	}
+
+	return newSecret, nil
+}
+
+// syncAddresses replaces nodeAddresses with the secret's data, converting
+// each value to a string.
+func (c *USBGatewayController) syncAddresses(secret *corev1.Secret) {
+	newAddresses := make(map[string]string, len(secret.Data))
+	for node, addr := range secret.Data {
+		newAddresses[node] = string(addr)
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.nodeAddresses = newAddresses
+}
diff --git a/images/virtualization-dra/internal/usb-gateway/informer/informer.go b/images/virtualization-dra/internal/usb-gateway/informer/informer.go
new file mode 100644
index 0000000000..a0876c484f
--- /dev/null
+++ b/images/virtualization-dra/internal/usb-gateway/informer/informer.go
@@ -0,0 +1,140 @@
+/*
+Copyright 2025 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package informer
+
+import (
+	"context"
+	"log/slog"
+	"math/rand/v2"
+	"sync"
+	"time"
+
+	"golang.org/x/sync/errgroup"
+	corev1 "k8s.io/api/core/v1"
+	resourcev1 "k8s.io/api/resource/v1"
+	"k8s.io/apimachinery/pkg/fields"
+	"k8s.io/client-go/kubernetes"
+	"k8s.io/client-go/tools/cache"
+)
+
+const (
+	PoolIndex   = "pool"
+	DriverIndex = "driver"
+)
+
+// NewFactory returns a Factory that builds informers on clientSet. When
+// resync is nil, a randomized period between 12 and 24 hours is used.
+func NewFactory(clientSet kubernetes.Interface, resync *time.Duration) *Factory {
+	var defaultResync time.Duration
+	if resync != nil {
+		defaultResync = *resync
+	} else {
+		defaultResync = resyncPeriod(12 * time.Hour)
+	}
+
+	return &Factory{
+		clientSet:        clientSet,
+		defaultResync:    defaultResync,
+		informers:        make(map[string]cache.SharedIndexInformer),
+		startedInformers: make(map[string]struct{}),
+	}
+}
+
+// Factory creates shared informers lazily and remembers which of them were
+// already started. mu guards both maps.
+type Factory struct {
+	clientSet     kubernetes.Interface
+	defaultResync time.Duration
+
+	informers        map[string]cache.SharedIndexInformer
+	startedInformers map[string]struct{}
+	mu               sync.Mutex
+}
+
+// Run starts every registered informer that was not started before and
+// blocks until all informers started by this call have returned, which
+// happens when ctx is done.
+func (f *Factory) Run(ctx context.Context) error {
+	group, ctx := errgroup.WithContext(ctx)
+
+	// Hold the lock only while walking the maps. Keeping it across
+	// group.Wait would block WaitForCacheSync and getInformer for as long
+	// as the informers run.
+	f.mu.Lock()
+	for name, informer := range f.informers {
+		if _, found := f.startedInformers[name]; found {
+			// skip informers that have already started.
+			slog.Info("SKIPPING informer", slog.String("name", name))
+			continue
+		}
+		slog.Info("STARTING informer", slog.String("name", name))
+		group.Go(func() error {
+			informer.Run(ctx.Done())
+			return nil
+		})
+		f.startedInformers[name] = struct{}{}
+	}
+	f.mu.Unlock()
+
+	return group.Wait()
+}
+
+// WaitForCacheSync blocks until all registered informers have synced or
+// stopCh is closed. The outcome is not reported to the caller.
+func (f *Factory) WaitForCacheSync(stopCh <-chan struct{}) {
+	var syncs []cache.InformerSynced
+
+	f.mu.Lock()
+	for name, informer := range f.informers {
+		slog.Info("Waiting for cache sync of informer", slog.String("name", name))
+		syncs = append(syncs, informer.HasSynced)
+	}
+	f.mu.Unlock()
+
+	cache.WaitForCacheSync(stopCh, syncs...)
+}
+
+// ResourceSlice returns the shared ResourceSlice informer, indexed by pool
+// name (PoolIndex) and by driver (DriverIndex).
+func (f *Factory) ResourceSlice() cache.SharedIndexInformer {
+	return f.getInformer("resourceSliceInformer", func() cache.SharedIndexInformer {
+		lw := cache.NewListWatchFromClient(f.clientSet.ResourceV1().RESTClient(), "resourceslices", corev1.NamespaceAll, fields.Everything())
+		return cache.NewSharedIndexInformer(lw, &resourcev1.ResourceSlice{}, f.defaultResync, cache.Indexers{
+			PoolIndex: func(obj interface{}) ([]string, error) {
+				return []string{obj.(*resourcev1.ResourceSlice).Spec.Pool.Name}, nil
+			},
+			DriverIndex: func(obj interface{}) ([]string, error) {
+				return []string{obj.(*resourcev1.ResourceSlice).Spec.Driver}, nil
+			},
+		})
+	})
+}
+
+// NamespacedSecret returns the shared Secret informer for namespace.
+func (f *Factory) NamespacedSecret(namespace string) cache.SharedIndexInformer {
+	// The cache key carries the namespace. With a constant key, a second
+	// call for another namespace would get the informer of the first one.
+	return f.getInformer("namespacedSecretInformer/"+namespace, func() cache.SharedIndexInformer {
+		lw := cache.NewListWatchFromClient(f.clientSet.CoreV1().RESTClient(), "secrets", namespace, fields.Everything())
+		return cache.NewSharedIndexInformer(lw, &corev1.Secret{}, f.defaultResync, cache.Indexers{})
+	})
+}
+
+// getInformer returns the informer registered under key, creating and
+// registering it with newFunc on first use.
+func (f *Factory) getInformer(key string, newFunc func() cache.SharedIndexInformer) cache.SharedIndexInformer {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	informer, ok := f.informers[key]
+	if ok {
+		return informer
+	}
+
+	informer = newFunc()
+	f.informers[key] = informer
+
+	return informer
+}
+
+// resyncPeriod computes the time interval a shared informer waits before resyncing with the api server
+func resyncPeriod(minResyncPeriod time.Duration) time.Duration {
+	factor := rand.Float64() + 1
+	return time.Duration(float64(minResyncPeriod.Nanoseconds()) * factor)
+}
diff --git a/images/virtualization-dra/internal/usb-gateway/labeler/labeler.go b/images/virtualization-dra/internal/usb-gateway/labeler/labeler.go
new file mode 100644
index 0000000000..f7f3fa94c8
--- /dev/null
+++ b/images/virtualization-dra/internal/usb-gateway/labeler/labeler.go
@@ -0,0 +1,99 @@
+/*
+Copyright 2025 Flant JSC
+
+Licensed under the Apache License,
Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package labeler
+
+import (
+	"context"
+	"fmt"
+	"maps"
+
+	"k8s.io/apimachinery/pkg/api/equality"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/dynamic"
+
+	"github.com/deckhouse/virtualization-dra/pkg/patch"
+)
+
+// Labeler adds and removes labels on a named object.
+type Labeler interface {
+	Label(ctx context.Context, name, namespace string, addLabels map[string]string, removeLabels []string) error
+}
+
+// genericLabeler labels objects of one resource type through the dynamic client.
+type genericLabeler struct {
+	client dynamic.Interface
+	gvr    schema.GroupVersionResource
+}
+
+// NewGenericLabeler returns a Labeler for objects of the given resource.
+func NewGenericLabeler(client dynamic.Interface, gvr schema.GroupVersionResource) Labeler {
+	return &genericLabeler{
+		client: client,
+		gvr:    gvr,
+	}
+}
+
+// Label fetches the object, removes removeLabels, applies addLabels on top
+// and sends a JSON patch when the resulting label set differs from the
+// current one. The patch tests the labels that were read, so it is rejected
+// if they changed in the meantime.
+func (l *genericLabeler) Label(ctx context.Context, name, namespace string, addLabels map[string]string, removeLabels []string) error {
+	// Empty inputs cannot change anything, whether nil or not.
+	if len(addLabels) == 0 && len(removeLabels) == 0 {
+		return nil
+	}
+
+	obj, err := l.client.Resource(l.gvr).Namespace(namespace).Get(ctx, name, metav1.GetOptions{})
+	if err != nil {
+		return err
+	}
+
+	oldLabels := obj.GetLabels()
+	newLabels := make(map[string]string)
+	maps.Copy(newLabels, oldLabels)
+	for _, k := range removeLabels {
+		delete(newLabels, k)
+	}
+	maps.Copy(newLabels, addLabels)
+
+	if equality.Semantic.DeepEqual(oldLabels, newLabels) {
+		return nil
+	}
+
+	// A JSON-patch "replace" requires the target to exist, so an object
+	// without any labels needs "add", which creates or replaces the member.
+	// NOTE(review): assumes patch.WithAdd accepts a map value the same way
+	// patch.WithReplace does - confirm against the patch package.
+	setLabels := patch.WithReplace("/metadata/labels", newLabels)
+	if oldLabels == nil {
+		setLabels = patch.WithAdd("/metadata/labels", newLabels)
+	}
+
+	patchBytes, err := patch.NewJSONPatch(
+		patch.WithTest("/metadata/labels", oldLabels),
+		setLabels,
+	).Bytes()
+	if err != nil {
+		return fmt.Errorf("failed to create patch: %w", err)
+	}
+
+	_, err = l.client.Resource(l.gvr).Namespace(namespace).Patch(ctx, name, types.JSONPatchType, patchBytes, metav1.PatchOptions{})
+	return err
+}
+
+// NodeLabeler labels objects of the core/v1 "nodes" resource.
+type NodeLabeler struct {
+	generic Labeler
+}
+
+// NewNodeLabeler returns a NodeLabeler backed by a generic labeler for nodes.
+func NewNodeLabeler(client dynamic.Interface) NodeLabeler {
+	return NodeLabeler{
+		generic: NewGenericLabeler(client, schema.GroupVersionResource{
+			Group:    "",
+			Version:  "v1",
+			Resource: "nodes",
+		}),
+	}
+}
+
+// Label forwards to the generic labeler unchanged.
+func (l NodeLabeler) Label(ctx context.Context, name, namespace string, addLabels map[string]string, removeLabels []string) error {
+	return l.generic.Label(ctx, name, namespace, addLabels, removeLabels)
+}
diff --git a/images/virtualization-dra/internal/usb-gateway/mark.go b/images/virtualization-dra/internal/usb-gateway/mark.go
new file mode 100644
index 0000000000..ddaf930b10
--- /dev/null
+++ b/images/virtualization-dra/internal/usb-gateway/mark.go
@@ -0,0 +1,57 @@
+/*
+Copyright 2026 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package usbgateway
+
+import (
+	"context"
+	"fmt"
+
+	"k8s.io/client-go/dynamic"
+
+	"github.com/deckhouse/virtualization-dra/internal/consts"
+	"github.com/deckhouse/virtualization-dra/internal/usb-gateway/labeler"
+)
+
+// Marker sets or clears the consts.USBGatewayLabel label on one node.
+type Marker struct {
+	nodeName string
+	labeler  labeler.NodeLabeler
+}
+
+// NewMarker returns a Marker for nodeName that labels through dynamicClient.
+func NewMarker(dynamicClient dynamic.Interface, nodeName string) *Marker {
+	marker := Marker{
+		nodeName: nodeName,
+		labeler:  labeler.NewNodeLabeler(dynamicClient),
+	}
+	return &marker
+}
+
+// Mark sets consts.USBGatewayLabel to "true" on the node.
+func (m Marker) Mark(ctx context.Context) error {
+	gatewayLabel := map[string]string{
+		consts.USBGatewayLabel: "true",
+	}
+	if err := m.labeler.Label(ctx, m.nodeName, "", gatewayLabel, nil); err != nil {
+		return fmt.Errorf("failed to label node %s: %w", m.nodeName, err)
+	}
+	return nil
+}
+
+// Unmark removes consts.USBGatewayLabel from the node.
+func (m Marker) Unmark(ctx context.Context) error {
+	obsoleteLabels := []string{consts.USBGatewayLabel}
+	if err := m.labeler.Label(ctx, m.nodeName, "", nil, obsoleteLabels); err != nil {
+		return fmt.Errorf("failed to unlabel node %s: %w", m.nodeName, err)
+	}
+	return nil
+}
diff --git a/images/virtualization-dra/internal/usb-gateway/usbgateway.go b/images/virtualization-dra/internal/usb-gateway/usbgateway.go
new file mode 100644
index 0000000000..35dbebc356
--- /dev/null
+++ b/images/virtualization-dra/internal/usb-gateway/usbgateway.go
@@ -0,0 +1,295 @@
+/*
+Copyright 2026 Flant JSC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package usbgateway
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"net"
+	"strconv"
+	"time"
+
+	resourcev1 "k8s.io/api/resource/v1"
+	"k8s.io/apimachinery/pkg/util/wait"
+
+	"github.com/deckhouse/virtualization-dra/internal/consts"
+	"github.com/deckhouse/virtualization-dra/internal/usb-gateway/informer"
+	"github.com/deckhouse/virtualization-dra/pkg/usbip"
+)
+
+// USBGateway attaches USB devices published by other nodes to the local node
+// and reports what is currently attached.
+type USBGateway interface {
+	// Attach exports the named device on its owning node and attaches it locally.
+	Attach(ctx context.Context, deviceName string) error
+	// Detach detaches the named device locally (if recorded) and unexports it.
+	Detach(deviceName string) error
+	// GetAttachedBusID returns the bus ID stored in the attach record for the device.
+	GetAttachedBusID(deviceName string) (string, error)
+	// GetAttachedDeviceNames returns the set of device names found in the attach record.
+	GetAttachedDeviceNames() (map[string]struct{}, error)
+}
+
+// Attach exports deviceName on the remote host, attaches it locally, waits
+// until the attach info for the returned port is visible and then persists an
+// attach record. It is a no-op when the attach record already has the device.
+//
+// NOTE(review): when usbIP.Attach fails the device stays exported; confirm
+// whether an Unexport rollback is expected here.
+func (c *USBGatewayController) Attach(ctx context.Context, deviceName string) error {
+	busID, host, port, err := c.getDeps(deviceName)
+	if err != nil {
+		return fmt.Errorf("failed to get attach deps: %w", err)
+	}
+
+	log := c.log.With(
+		slog.String("deviceName", deviceName),
+		slog.String("busID", busID),
+		slog.String("host", host),
+		slog.Int("port", port),
+	)
+
+	err = c.attachRecordManager.Refresh()
+	if err != nil {
+		return fmt.Errorf("failed to Refresh attach record: %w", err)
+	}
+
+	if entry := c.findEntry(deviceName); entry != nil {
+		log.Info("Device is already attached", slog.Any("entry", entry))
+		return nil
+	}
+
+	log.Info("Exporting USB device")
+	err = c.usbIP.Export(host, busID, port)
+	if err != nil {
+		return fmt.Errorf("failed to export device %s: %w", deviceName, err)
+	}
+
+	log.Info("Attaching USB device")
+	rhport, err := c.usbIP.Attach(host, busID, port)
+	if err != nil {
+		return fmt.Errorf("failed to attach device %s: %w", deviceName, err)
+	}
+
+	usedInfo, err := c.waitUsbAttachInfo(ctx, rhport)
+	if err != nil {
+		// Roll back the attach. The detach error is only mentioned when there
+		// is one: formatting a nil error with %w renders as "%!w(<nil>)".
+		if detachErr := c.usbIP.Detach(rhport); detachErr != nil {
+			return fmt.Errorf("failed to wait for usb attach info: %w, detach error: %w", err, detachErr)
+		}
+		return fmt.Errorf("failed to wait for usb attach info: %w", err)
+	}
+
+	return c.storeAttachRecordOrDetach(deviceName, usedInfo.LocalBusID, rhport)
+}
+
+// Detach detaches deviceName from the local port stored in the attach record
+// (when such an entry exists) and then unexports it on the remote host.
+func (c *USBGatewayController) Detach(deviceName string) error {
+	busID, host, port, err := c.getDeps(deviceName)
+	if err != nil {
+		return err
+	}
+
+	log := c.log.With(
+		slog.String("deviceName", deviceName),
+		slog.String("busID", busID),
+		slog.String("host", host),
+		slog.Int("port", port),
+	)
+
+	err = c.attachRecordManager.Refresh()
+	if err != nil {
+		return fmt.Errorf("failed to Refresh attach record: %w", err)
+	}
+
+	entry := c.findEntry(deviceName)
+	if entry != nil {
+		log.Info("Detaching USB device")
+		err = c.usbIP.Detach(entry.Rhport)
+		if err != nil {
+			return fmt.Errorf("failed to detach device %s: %w", deviceName, err)
+		}
+	}
+
+	log.Info("Unexporting USB device")
+	err = c.usbIP.Unexport(host, busID, port)
+	if err != nil {
+		return fmt.Errorf("failed to unexport device %s: %w", deviceName, err)
+	}
+
+	return nil
+}
+
+// GetAttachedBusID refreshes the attach record and returns the BusID stored
+// for deviceName, or an error when the device has no entry.
+func (c *USBGatewayController) GetAttachedBusID(deviceName string) (string, error) {
+	if err := c.attachRecordManager.Refresh(); err != nil {
+		return "", fmt.Errorf("failed to Refresh attach record: %w", err)
+	}
+
+	if entry := c.findEntry(deviceName); entry != nil {
+		return entry.BusID, nil
+	}
+
+	return "", fmt.Errorf("device %s is not attached", deviceName)
+}
+
+// GetAttachedDeviceNames refreshes the attach record and returns the device
+// names of all its entries as a set.
+func (c *USBGatewayController) GetAttachedDeviceNames() (map[string]struct{}, error) {
+	if err := c.attachRecordManager.Refresh(); err != nil {
+		return nil, fmt.Errorf("failed to Refresh attach record: %w", err)
+	}
+
+	entries := c.attachRecordManager.GetEntries()
+
+	names := make(map[string]struct{}, len(entries))
+	for _, entry := range entries {
+		names[entry.DeviceName] = struct{}{}
+	}
+
+	return names, nil
+}
+
+// getDeps resolves everything needed to reach deviceName: its bus ID attribute
+// and the host/port registered for the pool that publishes it.
+func (c *USBGatewayController) getDeps(deviceName string) (string, string, int, error) {
+	device, pool, err := c.getDevice(deviceName)
+	if err != nil {
+		return "", "", -1, err
+	}
+
+	busID, err := c.getBusID(device)
+	if err != nil {
+		return "", "", -1, err
+	}
+
+	host, port, err := c.resolveRemoteAddress(pool)
+	if err != nil {
+		return "", "", -1, err
+	}
+
+	return busID, host, port, nil
+}
+
+// getDevice looks deviceName up in the driver's resource slices and returns it
+// together with the name of the pool it belongs to. A device whose pool name
+// equals this node's name is rejected.
+func (c *USBGatewayController) getDevice(deviceName string) (*resourcev1.Device, string, error) {
+	resourceSlices, err := c.getVirtualizationDraResourceSlices()
+	if err != nil {
+		return nil, "", err
+	}
+
+	for _, slice := range resourceSlices {
+		for _, device := range slice.Spec.Devices {
+			if device.Name == deviceName {
+				if slice.Spec.Pool.Name == c.nodeName {
+					return nil, "", fmt.Errorf("device is not allowed to be attached to itself")
+				}
+
+				return &device, slice.Spec.Pool.Name, nil
+			}
+		}
+	}
+
+	return nil, "", fmt.Errorf("device %s is not found", deviceName)
+}
+
+// getVirtualizationDraResourceSlices returns deep copies of all resource slices
+// indexed under the virtualization-dra USB driver name.
+func (c *USBGatewayController) getVirtualizationDraResourceSlices() ([]resourcev1.ResourceSlice, error) {
+	slicesObj, err := c.resourceSliceIndexer.ByIndex(informer.DriverIndex, consts.VirtualizationDraUSBDriverName)
+	if err != nil {
+		return nil, err
+	}
+	var slices []resourcev1.ResourceSlice
+	for _, obj := range slicesObj {
+		slice, ok := obj.(*resourcev1.ResourceSlice)
+		if !ok {
+			return nil, fmt.Errorf("unexpected type of resource slice: %T", obj)
+		}
+		slices = append(slices, *slice.DeepCopy())
+	}
+	return slices, nil
+}
+
+// getBusID returns the string value of the consts.AttrBusID device attribute.
+func (c *USBGatewayController) getBusID(device *resourcev1.Device) (string, error) {
+	if attr, ok := device.Attributes[consts.AttrBusID]; ok && attr.StringValue != nil {
+		return *attr.StringValue, nil
+	}
+	return "", fmt.Errorf("busID attribute is not exist")
+}
+
+// resolveRemoteAddress splits the "host:port" address stored for pool in
+// nodeAddresses. The map is read under the controller's read lock.
+func (c *USBGatewayController) resolveRemoteAddress(pool string) (string, int, error) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	addr, exist := c.nodeAddresses[pool]
+	if !exist {
+		return "", -1, fmt.Errorf("pool %s is not found", pool)
+	}
+
+	host, portStr, err := net.SplitHostPort(addr)
+	if err != nil {
+		return "", -1, fmt.Errorf("failed to split host and port: %w", err)
+	}
+	port, err := strconv.Atoi(portStr)
+	if err != nil {
+		return "", -1, fmt.Errorf("failed to parse port: %w", err)
+	}
+
+	return host, port, nil
+}
+
+// storeAttachRecordOrDetach persists the attach entry, retrying up to
+// maxRetries times. If it cannot be stored the device is detached again (also
+// with retries) and an error is returned in every case: the store error alone
+// when the rollback succeeded, or both errors when the rollback failed too.
+func (c *USBGatewayController) storeAttachRecordOrDetach(deviceName, busID string, rhport int) (err error) {
+	entry := AttachEntry{
+		Rhport:     rhport,
+		BusID:      busID,
+		DeviceName: deviceName,
+	}
+
+	const maxRetries = 3
+
+	// Kept in its own variable so the detach loop below cannot overwrite it.
+	var storeErr error
+	for range maxRetries {
+		c.log.Info("Adding entry to attach record", slog.Any("entry", entry))
+		storeErr = c.attachRecordManager.AddEntry(entry)
+		if storeErr == nil {
+			return nil
+		}
+
+		c.log.Error("Failed to add entry to attach record", slog.Any("error", storeErr))
+	}
+
+	var detachErr error
+	for range maxRetries {
+		c.log.Info("Detaching device", slog.Any("deviceName", deviceName))
+		detachErr = c.usbIP.Detach(rhport)
+		if detachErr == nil {
+			return fmt.Errorf("failed to store attach record: %w", storeErr)
+		}
+
+		c.log.Error("Failed to detach device", slog.Any("error", detachErr))
+	}
+
+	return fmt.Errorf("failed to detach device: %w (after failing to store attach record: %w)", detachErr, storeErr)
+}
+
+// findEntry returns the attach record entry for deviceName, or nil when the
+// record has none. It does not refresh the record itself.
+func (c *USBGatewayController) findEntry(deviceName string) *AttachEntry {
+	for _, entry := range c.attachRecordManager.GetEntries() {
+		if entry.DeviceName == deviceName {
+			return &entry
+		}
+	}
+	return nil
+}
+
+// waitUsbAttachInfo polls the attach info once per second, starting
+// immediately, until an entry for rhport appears or ctx is cancelled. Errors
+// from GetAttachInfo are logged and treated as "not ready yet".
+func (c *USBGatewayController) waitUsbAttachInfo(ctx context.Context, rhport int) (*usbip.AttachInfo, error) {
+	// command attach was successful, but we need to wait until usb is real attached
+	var usedInfo *usbip.AttachInfo
+
+	err := wait.PollUntilContextCancel(ctx, time.Second, true, func(ctx context.Context) (bool, error) {
+		c.log.Info("Get attach info for store localBusID")
+		infos, err := c.usbIP.GetAttachInfo()
+		if err != nil {
+			c.log.Info("Failed to get used info", slog.String("error", err.Error()))
+			return false, nil
+		}
+		for _, info := range infos {
+			if info.Port == rhport {
+				usedInfo = &info
+				return true, nil
+			}
+		}
+		c.log.Info("Usb are not attached yet")
+		return false, nil
+	})
+
+	return usedInfo, err
+}
diff --git a/images/virtualization-dra/internal/usb/convert.go b/images/virtualization-dra/internal/usb/convert.go
index adc1f7fbeb..d38fd2e96d 100644
--- a/images/virtualization-dra/internal/usb/convert.go
+++ b/images/virtualization-dra/internal/usb/convert.go
@@ -19,10 +19,12 @@ package usb
 import (
 	"fmt"
 
+	corev1
"k8s.io/api/core/v1" resourcev1 "k8s.io/api/resource/v1" "k8s.io/utils/ptr" "github.com/deckhouse/virtualization-dra/internal/consts" + "github.com/deckhouse/virtualization-dra/internal/featuregates" ) func (d *Device) ToAPIDevice(nodeName string) *resourcev1.Device { @@ -77,10 +79,43 @@ func convertToAPIDevice(usbDevice Device, nodeName string) *resourcev1.Device { StringValue: ptr.To(usbDevice.DevicePath), }, consts.AttrUsbAddress: { - StringValue: ptr.To(fmt.Sprintf("%s:%s", usbDevice.Bus, usbDevice.DeviceNumber)), + StringValue: ptr.To(usbAddressFromDev(&usbDevice)), }, }, } + if !featuregates.Default().USBGatewayEnabled() { + device.NodeName = ptr.To(nodeName) + } + + if featuregates.Default().USBNodeLocalMultiAllocationEnabled() { + device.BindsToNode = ptr.To(true) // Required DRADeviceBindingConditions,DRAResourceClaimDeviceStatus + device.AllowMultipleAllocations = ptr.To(true) // Required DRAConsumableCapacity + } + return device } + +func getNodeSelector() *corev1.NodeSelector { + return &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: consts.USBGatewayLabel, + Operator: corev1.NodeSelectorOpIn, + Values: []string{"true"}, + }, + }, + }, + }, + } +} + +func usbAddressFromDev(dev *Device) string { + return usbAddress(dev.Bus.String(), dev.DeviceNumber.String()) +} + +func usbAddress(bus, deviceNumber string) string { + return fmt.Sprintf("%s:%s", bus, deviceNumber) +} diff --git a/images/virtualization-dra/internal/usb/device.go b/images/virtualization-dra/internal/usb/device.go index d75aeb7609..8d63d5c7a9 100644 --- a/images/virtualization-dra/internal/usb/device.go +++ b/images/virtualization-dra/internal/usb/device.go @@ -51,9 +51,9 @@ type Device struct { } func (d *Device) GetName(nodeName string) string { - // usb----- - // usb-003-005-e39-f100 - unhashed := fmt.Sprintf("%s-%s-%s-%s-%s", d.Bus.String(), d.DeviceNumber.String(), d.VendorID.String(), 
d.ProductID.String(), nodeName) + // usb---- + // usb-3-1-e39-f100 + unhashed := fmt.Sprintf("%s-%s-%s-%s", d.BusID, d.VendorID.String(), d.ProductID.String(), nodeName) hash := sha1.Sum([]byte(unhashed)) hashedString := hex.EncodeToString(hash[:]) diff --git a/images/virtualization-dra/internal/usb/discovery.go b/images/virtualization-dra/internal/usb/discovery.go index 1310b56b54..9fde30965f 100644 --- a/images/virtualization-dra/internal/usb/discovery.go +++ b/images/virtualization-dra/internal/usb/discovery.go @@ -16,11 +16,34 @@ limitations under the License. package usb -func (s *AllocationStore) discoveryPluggedUSBDevices() DeviceSet { +import ( + "github.com/deckhouse/virtualization-dra/internal/featuregates" +) + +func (s *AllocationStore) discoveryPluggedUSBDevices() (DeviceSet, DeviceSet, error) { allUSBDevices := s.monitor.GetDevices() + + busIDSet := make(map[string]struct{}) + if featuregates.Default().USBGatewayEnabled() { + infos, err := s.usbipInfoGetter.GetAttachInfo() + if err != nil { + return nil, nil, err + } + for _, info := range infos { + busIDSet[info.LocalBusID] = struct{}{} + } + } + usbDeviceSet := NewDeviceSet() + usbipDeviceSet := NewDeviceSet() + for _, device := range allUSBDevices { - usbDeviceSet.Insert(toDevice(&device)) + if _, ok := busIDSet[device.BusID]; ok { + usbipDeviceSet.Insert(toDevice(&device)) + } else { + usbDeviceSet.Insert(toDevice(&device)) + } } - return usbDeviceSet + + return usbDeviceSet, usbipDeviceSet, nil } diff --git a/images/virtualization-dra/internal/usb/driver.go b/images/virtualization-dra/internal/usb/driver.go index b4c6fbb20f..34357f7aba 100644 --- a/images/virtualization-dra/internal/usb/driver.go +++ b/images/virtualization-dra/internal/usb/driver.go @@ -16,4 +16,6 @@ limitations under the License. 
package usb -const DriverName = "virtualization-usb" +import "github.com/deckhouse/virtualization-dra/internal/consts" + +const DriverName = consts.VirtualizationDraUSBDriverName diff --git a/images/virtualization-dra/internal/usb/store.go b/images/virtualization-dra/internal/usb/store.go index b5d8b4e753..5ad257a9fb 100644 --- a/images/virtualization-dra/internal/usb/store.go +++ b/images/virtualization-dra/internal/usb/store.go @@ -17,16 +17,23 @@ limitations under the License. package usb import ( + "cmp" "context" + "encoding/json" "fmt" "log/slog" + "path/filepath" + "slices" + "strconv" "strings" "sync" "github.com/containerd/nri/pkg/api" resourcev1 "k8s.io/api/resource/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/resourceslice" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/utils/ptr" @@ -35,20 +42,28 @@ import ( "github.com/deckhouse/virtualization-dra/internal/cdi" "github.com/deckhouse/virtualization-dra/internal/consts" + "github.com/deckhouse/virtualization-dra/internal/featuregates" + usbgateway "github.com/deckhouse/virtualization-dra/internal/usb-gateway" "github.com/deckhouse/virtualization-dra/pkg/libusb" + "github.com/deckhouse/virtualization-dra/pkg/patch" + "github.com/deckhouse/virtualization-dra/pkg/usbip" ) -func NewAllocationStore(ctx context.Context, nodeName string, cdiManager cdi.Manager, monitor libusb.Monitor) (*AllocationStore, error) { +func NewAllocationStore(ctx context.Context, nodeName string, cdiManager cdi.Manager, monitor libusb.Monitor, usbGateway usbgateway.USBGateway, kubeClient kubernetes.Interface) (*AllocationStore, error) { store := &AllocationStore{ - nodeName: nodeName, - cdi: cdiManager, - monitor: monitor, - log: slog.With(slog.String("component", "usb-allocation-store")), - updateChannel: make(chan resourceslice.DriverResources, 2), - discoverPluggedUSBDevices: 
NewDeviceSet(), - allocatableDevices: make(map[string]resourcev1.Device), - allocatedDevices: sets.New[string](), - resourceClaimAllocations: make(map[types.UID][]string), + nodeName: nodeName, + cdi: cdiManager, + monitor: monitor, + usbGateway: usbGateway, + kubeClient: kubeClient, + log: slog.With(slog.String("component", "usb-allocation-store")), + updateChannel: make(chan resourceslice.DriverResources, 2), + discoverPluggedUSBDevices: NewDeviceSet(), + allocatableDevices: make(map[string]resourcev1.Device), + allocatedDevices: sets.New[string](), + usbipAllocatedDevicesCount: make(map[string]int), + resourceClaimAllocations: make(map[types.UID][]string), + usbipInfoGetter: usbip.NewUSBAttacher(), } store.subscribeToDeviceChanges(ctx) @@ -70,20 +85,30 @@ type AllocationStore struct { updateChannel chan resourceslice.DriverResources mu sync.RWMutex - monitor libusb.Monitor + usbGateway usbgateway.USBGateway + usbipInfoGetter usbip.AttachInfoGetter + monitor libusb.Monitor + kubeClient kubernetes.Interface - discoverPluggedUSBDevices DeviceSet - allocatableDevices map[string]resourcev1.Device + discoverPluggedUSBDevices DeviceSet + discoverUsbIpPluggedUSBDevices DeviceSet + allocatableDevices map[string]resourcev1.Device - allocatedDevices sets.Set[string] - resourceClaimAllocations map[types.UID][]string + allocatedDevices sets.Set[string] + usbipAllocatedDevicesCount map[string]int + resourceClaimAllocations map[types.UID][]string } func (s *AllocationStore) sync() error { s.mu.Lock() defer s.mu.Unlock() - discoverPluggedUSBDevices := s.discoveryPluggedUSBDevices() + discoverPluggedUSBDevices, discoverUsbIpPluggedUSBDevices, err := s.discoveryPluggedUSBDevices() + if err != nil { + return err + } + + s.discoverUsbIpPluggedUSBDevices = discoverUsbIpPluggedUSBDevices if discoverPluggedUSBDevices.Equal(s.discoverPluggedUSBDevices) { return nil @@ -135,7 +160,7 @@ func (s *AllocationStore) UpdateChannel() chan resourceslice.DriverResources { return s.updateChannel 
} -func (s *AllocationStore) Prepare(_ context.Context, claim *resourcev1.ResourceClaim) ([]*drapbv1.Device, error) { +func (s *AllocationStore) Prepare(ctx context.Context, claim *resourcev1.ResourceClaim) ([]*drapbv1.Device, error) { s.mu.Lock() defer s.mu.Unlock() @@ -147,19 +172,66 @@ func (s *AllocationStore) Prepare(_ context.Context, claim *resourcev1.ResourceC preparedDevices := make(cdi.PreparedDevices, len(claim.Status.Allocation.Devices.Results)) + usbGatewayEnabled := featuregates.Default().USBGatewayEnabled() + usbNodeLocalMultiAllocationEnabled := featuregates.Default().USBNodeLocalMultiAllocationEnabled() + + usbIPAllocatedDevices := make(map[string]struct{}) + usbDeviceInfos := make([]usbDeviceInfo, 0) + for i, result := range claim.Status.Allocation.Devices.Results { - if s.allocatedDevices.Has(result.Device) { - return nil, fmt.Errorf("device %v is already allocated", result.Device) + allocated := s.allocatedDevices.Has(result.Device) + if allocated && !usbNodeLocalMultiAllocationEnabled { + return nil, fmt.Errorf("device %v is already allocated. 
For USB node local multi allocation, please set feature gate %q to true", result.Device, featuregates.USBNodeLocalMultiAllocation) } - usbDevice, exists := s.allocatableDevices[result.Device] - if !exists { - return nil, fmt.Errorf("requested device is not allocatable: %v", result.Device) + isUSBGatewayRequest := s.isUSBGatewayRequest(&result) + + if !usbGatewayEnabled && isUSBGatewayRequest { + return nil, fmt.Errorf("claim %s/%s has usb gateway request but usb gateway is disabled", claim.Namespace, claim.Name) } - containerEditsOptions, err := newContainerEditsOptions(&usbDevice) - if err != nil { - return nil, err + var containerEditsOptions containerEditsOptions + + if isUSBGatewayRequest { + if !allocated { + err := s.usbGateway.Attach(ctx, result.Device) + if err != nil { + return nil, err + } + } + + busID, err := s.usbGateway.GetAttachedBusID(result.Device) + if err != nil { + return nil, err + } + + usbDevice := s.getUsbGatewayUsbDevice(busID) + if usbDevice == nil { + return nil, fmt.Errorf("usb device %s is not found", busID) + } + + containerEditsOptions = newContainerEditsOptionsForUSBGateway(result.Device, usbDevice).withUserGroup(claim) + + usbIPAllocatedDevices[result.Device] = struct{}{} + usbDeviceInfos = append(usbDeviceInfos, usbDeviceInfo{ + DeviceName: result.Device, + UsbAddress: usbAddressFromDev(usbDevice), + }) + } else { + usbDevice, exists := s.allocatableDevices[result.Device] + if !exists { + return nil, fmt.Errorf("requested device is not allocatable: %v", result.Device) + } + + opts, err := newContainerEditsOptions(&usbDevice) + if err != nil { + return nil, err + } + containerEditsOptions = opts.withUserGroup(claim) + usbDeviceInfos = append(usbDeviceInfos, usbDeviceInfo{ + DeviceName: result.Device, + UsbAddress: usbAddress(opts.Bus, opts.DeviceNum), + }) } edits := s.makeContainerEdits(claimUID, containerEditsOptions) @@ -181,15 +253,130 @@ func (s *AllocationStore) Prepare(_ context.Context, claim *resourcev1.ResourceC return 
nil, fmt.Errorf("unable to create CDI spec file for claim: %w", err) } + err = s.ensureAnnotationDeviceAddresses(ctx, claim, usbDeviceInfos) + if err != nil { + return nil, err + } + devices := preparedDevices.GetDevices() for _, device := range devices { s.allocatedDevices.Insert(device.DeviceName) s.resourceClaimAllocations[claim.UID] = append(s.resourceClaimAllocations[claim.UID], device.DeviceName) + + if _, ok := usbIPAllocatedDevices[device.DeviceName]; ok { + s.usbipAllocatedDevicesCount[device.DeviceName]++ + } } return devices, nil } +func (s *AllocationStore) getUsbGatewayUsbDevice(busID string) *Device { + for _, device := range s.discoverUsbIpPluggedUSBDevices.UnsortedList() { + if device.BusID == busID { + return &device + } + } + // usb device is not found in cache + // load usb device from sysfs + dev, err := libusb.LoadUSBDevice(filepath.Join(libusb.PathToUSBDevices, busID)) + if err == nil { + return ptr.To(toDevice(&dev)) + } + + return nil +} + +func (s *AllocationStore) ensureAnnotationDeviceAddresses(ctx context.Context, claim *resourcev1.ResourceClaim, usbDeviceInfos []usbDeviceInfo) error { + path := fmt.Sprintf("/metadata/annotations/%s", patch.EscapeJSONPointer(consts.AnnUSBDeviceAddresses)) + + slices.SortFunc(usbDeviceInfos, func(a, b usbDeviceInfo) int { + return cmp.Compare(a.DeviceName, b.DeviceName) + }) + + oldAnno, oldUsbDeviceInfos, err := loadUsbDeviceInfos(claim) + if err != nil { + return err + } + + jp := patch.NewJSONPatch() + + if oldUsbDeviceInfos == nil { + jp.Append(patch.WithAdd(path, patch.AsJsonString{Data: usbDeviceInfos})) + } else { + slices.SortFunc(oldUsbDeviceInfos, func(a, b usbDeviceInfo) int { + return cmp.Compare(a.DeviceName, b.DeviceName) + }) + if slices.Equal(oldUsbDeviceInfos, usbDeviceInfos) { + return nil + } + + jp.Append( + patch.WithTest(path, oldAnno), + patch.WithReplace(path, patch.AsJsonString{Data: usbDeviceInfos}), + ) + } + + bytes, err := jp.Bytes() + if err != nil { + return 
fmt.Errorf("failed to generate patch: %w", err)
+	}
+
+	s.log.Debug("Patching resource claim", slog.String("uid", string(claim.UID)), slog.Any("patch", string(bytes)))
+	_, err = s.kubeClient.ResourceV1().ResourceClaims(claim.Namespace).Patch(ctx, claim.Name, types.JSONPatchType, bytes, metav1.PatchOptions{})
+	if err != nil {
+		return fmt.Errorf("failed to patch resource claim: %w", err)
+	}
+	return nil
+}
+
+// usbDeviceInfo is one element of the JSON list stored in the
+// consts.AnnUSBDeviceAddresses annotation.
+type usbDeviceInfo struct {
+	DeviceName string `json:"deviceName"`
+	UsbAddress string `json:"usbAddress"`
+}
+
+// loadUsbDeviceInfos reads the consts.AnnUSBDeviceAddresses annotation of obj.
+// It returns the raw annotation value and its decoded list, or ("", nil, nil)
+// when the annotation is absent.
+func loadUsbDeviceInfos(obj metav1.Object) (string, []usbDeviceInfo, error) {
+	var usbDeviceInfos []usbDeviceInfo
+	if data, ok := obj.GetAnnotations()[consts.AnnUSBDeviceAddresses]; ok {
+		err := json.Unmarshal([]byte(data), &usbDeviceInfos)
+		if err != nil {
+			return "", nil, fmt.Errorf("failed to unmarshal annotation %s: %w", consts.AnnUSBDeviceAddresses, err)
+		}
+		return data, usbDeviceInfos, nil
+	}
+	return "", nil, nil
+}
+
+// newContainerEditsOptionsForUSBGateway builds container edit options named
+// deviceName from the fields of the locally visible usbDevice.
+func newContainerEditsOptionsForUSBGateway(deviceName string, usbDevice *Device) containerEditsOptions {
+	return containerEditsOptions{
+		Name:       deviceName,
+		DevicePath: usbDevice.DevicePath,
+		DeviceNum:  usbDevice.DeviceNumber.String(),
+		Bus:        usbDevice.Bus.String(),
+		Major:      int64(usbDevice.Major),
+		Minor:      int64(usbDevice.Minor),
+	}
+}
+
+// withUserGroup returns a copy of c with UID/GID taken from the claim's
+// consts.AnnUSBDeviceUser and consts.AnnUSBDeviceGroup annotations. A missing
+// or unparsable annotation leaves the corresponding field untouched, so a
+// malformed value can never turn into UID/GID 0.
+func (c containerEditsOptions) withUserGroup(claim *resourcev1.ResourceClaim) containerEditsOptions {
+	if anno := claim.GetAnnotations()[consts.AnnUSBDeviceUser]; anno != "" {
+		uid, err := strconv.ParseUint(anno, 10, 32)
+		if err != nil {
+			slog.Warn("Failed to parse annotation", slog.String("annotation", consts.AnnUSBDeviceUser), slog.String("value", anno), slog.Any("error", err))
+		} else {
+			c.UID = ptr.To(uint32(uid))
+		}
+	}
+	if anno := claim.GetAnnotations()[consts.AnnUSBDeviceGroup]; anno != "" {
+		gid, err := strconv.ParseUint(anno, 10, 32)
+		if err != nil {
+			slog.Warn("Failed to parse annotation", slog.String("annotation", consts.AnnUSBDeviceGroup), slog.String("value", anno), slog.Any("error", err))
+		} else {
+			// Only trust gid when parsing succeeded; on error ParseUint
+			// returns 0 or the maximum value, neither of which was requested.
+			c.GID = ptr.To(uint32(gid))
+		}
+	}
+	return c
+}
+
 func newContainerEditsOptions(device *resourcev1.Device) (containerEditsOptions, error) {
 	opts := containerEditsOptions{
 		Name: device.Name,
@@ -238,6 +425,12 @@ func newContainerEditsOptions(device *resourcev1.Device) (containerEditsOptions,
 	return opts, nil
 }
 
+func (s *AllocationStore) isUSBGatewayRequest(result *resourcev1.DeviceRequestAllocationResult) bool {
+	// virtualization-dra creates slices with pool name by node name
+	// if pool not equal our node name, it is usb gateway request
+	return result.Pool != s.nodeName
+}
+
 type containerEditsOptions struct {
 	Name       string
 	DevicePath string
@@ -245,6 +438,8 @@ type containerEditsOptions struct {
 	Bus        string
 	Major      int64
 	Minor      int64
+	UID        *uint32
+	GID        *uint32
 }
 
 func (s *AllocationStore) makeContainerEdits(claimUID string, opts containerEditsOptions) *cdiapi.ContainerEdits {
@@ -268,8 +463,8 @@ func (s *AllocationStore) makeContainerEdits(claimUID string, opts containerEdit
 					Major:       opts.Major,
 					Minor:       opts.Minor,
 					Permissions: "mrw",
-					UID:         ptr.To(uint32(107)), // qemu user. TODO: make this configurable
-					GID:         ptr.To(uint32(107)), // qemu group.
TODO: make this configurable + UID: opts.UID, + GID: opts.GID, }, }, }, @@ -286,8 +481,21 @@ func (s *AllocationStore) Unprepare(_ context.Context, claimUID types.UID) error return fmt.Errorf("unable to delete CDI spec file for claim: %w", err) } + usbGatewayEnabled := featuregates.Default().USBGatewayEnabled() + allocatedDevices := s.resourceClaimAllocations[claimUID] for _, device := range allocatedDevices { + if usbGatewayEnabled { + if count := s.usbipAllocatedDevicesCount[device]; count == 1 { + if err := s.usbGateway.Detach(device); err != nil { + return fmt.Errorf("failed to detach device %s: %w", device, err) + } + delete(s.usbipAllocatedDevicesCount, device) + } else if count > 1 { + s.usbipAllocatedDevicesCount[device]-- + } + } + s.allocatedDevices.Delete(device) } delete(s.resourceClaimAllocations, claimUID) @@ -304,6 +512,15 @@ func (s *AllocationStore) Synchronize(_ context.Context, pods []*api.PodSandbox, containersByPodSandboxID[ctr.PodSandboxId] = append(containersByPodSandboxID[ctr.PodSandboxId], ctr) } + var uspIPDeviceNames map[string]struct{} + if featuregates.Default().USBGatewayEnabled() { + names, err := s.usbGateway.GetAttachedDeviceNames() + if err != nil { + return nil, fmt.Errorf("failed to get attached device names: %w", err) + } + uspIPDeviceNames = names + } + for _, pod := range pods { s.log.Info("Synchronize pod", slog.String("name", pod.Name), slog.String("namespace", pod.Namespace)) ctrs := containersByPodSandboxID[pod.Id] @@ -318,6 +535,10 @@ func (s *AllocationStore) Synchronize(_ context.Context, pods []*api.PodSandbox, s.resourceClaimAllocations[claimUID] = append(s.resourceClaimAllocations[claimUID], deviceNames...) 
for _, deviceName := range deviceNames { s.allocatedDevices.Insert(deviceName) + + if _, ok := uspIPDeviceNames[deviceName]; ok { + s.usbipAllocatedDevicesCount[deviceName]++ + } } } } @@ -362,6 +583,10 @@ func (s *AllocationStore) makeResources(devices []resourcev1.Device) resourcesli }, } + if featuregates.Default().USBGatewayEnabled() { + pool.NodeSelector = getNodeSelector() + } + return resourceslice.DriverResources{ Pools: map[string]resourceslice.Pool{ poolName: pool, diff --git a/images/virtualization-dra/pkg/controller/controller.go b/images/virtualization-dra/pkg/controller/controller.go new file mode 100644 index 0000000000..7fd01b3d9d --- /dev/null +++ b/images/virtualization-dra/pkg/controller/controller.go @@ -0,0 +1,117 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package controller
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"time"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/tools/cache"
+	"k8s.io/client-go/util/workqueue"
+)
+
+// ObjectKeyFunc returns the queue key for obj as computed by
+// cache.DeletionHandlingMetaNamespaceKeyFunc.
+func ObjectKeyFunc(obj interface{}) (string, error) {
+	return cache.DeletionHandlingMetaNamespaceKeyFunc(obj)
+}
+
+// MetaObjectKeyFunc returns KeyFunc applied to the object's namespace and name.
+func MetaObjectKeyFunc(obj metav1.Object) string {
+	return KeyFunc(obj.GetNamespace(), obj.GetName())
+}
+
+// KeyFunc renders namespace and name through types.NamespacedName.String().
+// NOTE(review): for an empty namespace this presumably yields "/name", while
+// ObjectKeyFunc is expected to yield "name" — verify that cluster-scoped keys
+// from both helpers are never mixed in one queue.
+func KeyFunc(namespace, name string) string {
+	return types.NamespacedName{
+		Namespace: namespace,
+		Name:      name,
+	}.String()
+}
+
+// Controller is what Run needs from a concrete controller.
+type Controller interface {
+	// Queue returns the rate-limited work queue the workers take keys from.
+	Queue() workqueue.TypedRateLimitingInterface[string]
+	// HasSynced is passed to cache.WaitForCacheSync before workers start.
+	HasSynced() bool
+	// Sync processes one key; a returned error re-enqueues the key rate-limited.
+	Sync(ctx context.Context, key string) error
+	// Logger returns the logger used by the run loop and its workers.
+	Logger() *slog.Logger
+}
+
+// Run wraps controller in the internal run loop and blocks until ctx is done
+// or the cache sync wait fails.
+func Run(controller Controller, ctx context.Context, workers int) error {
+	return newController(controller).Run(ctx, workers)
+}
+
+// newController captures the queue and logger of c once, at construction time.
+func newController(c Controller) *controller {
+	return &controller{
+		controller: c,
+		queue:      c.Queue(),
+		log:        c.Logger(),
+	}
+}
+
+// controller drives a Controller: it owns the worker goroutines and the
+// queue shutdown.
+type controller struct {
+	controller Controller
+	queue      workqueue.TypedRateLimitingInterface[string]
+	log        *slog.Logger
+}
+
+// Run waits for the caches to sync, starts `workers` goroutines and blocks
+// until ctx is cancelled. On return the deferred calls log the shutdown, shut
+// the queue down and run utilruntime.HandleCrash, in that order.
+func (c *controller) Run(ctx context.Context, workers int) error {
+	defer utilruntime.HandleCrash()
+	defer c.queue.ShutDown()
+
+	c.log.Info("Starting controller")
+	defer c.log.Info("Shutting down controller")
+
+	if !cache.WaitForCacheSync(ctx.Done(), c.controller.HasSynced) {
+		return fmt.Errorf("failed to wait for caches to sync")
+	}
+
+	c.log.Info("Starting workers controller")
+	for i := 0; i < workers; i++ {
+		// UntilWithContext calls worker again one second after it returns,
+		// until ctx is done.
+		go wait.UntilWithContext(ctx, c.worker, time.Second)
+	}
+
+	<-ctx.Done()
+	return nil
+}
+
+// worker takes keys from the queue one at a time and hands them to Sync until
+// the queue reports that it has been shut down.
+func (c *controller) worker(ctx context.Context) {
+	// workFunc handles a single key and reports whether the worker must stop.
+	workFunc := func(ctx context.Context) bool {
+		key, quit := c.queue.Get()
+		if quit {
+			return true
+		}
+		// Done must be called for every key obtained from Get.
+		defer c.queue.Done(key)
+
+		if err := c.controller.Sync(ctx, key); err != nil {
+			c.log.Error("re-enqueuing", slog.String("key", key), slog.Any("err", err))
+			c.queue.AddRateLimited(key)
+		} else {
+			c.log.Info(fmt.Sprintf("processed %v", key))
+			// Forget resets the rate limiter history for this key.
+			c.queue.Forget(key)
+		}
+		return false
+	}
+	for {
+		quit := workFunc(ctx)
+
+		if quit {
+			return
+		}
+	}
+}
diff --git a/images/virtualization-dra/pkg/libusb/device-store.go b/images/virtualization-dra/pkg/libusb/device-store.go
index c7e140e3e6..7e2f9e3c70 100644
--- a/images/virtualization-dra/pkg/libusb/device-store.go
+++ b/images/virtualization-dra/pkg/libusb/device-store.go
@@ -103,10 +103,8 @@ func (s *USBDeviceStore) GetDeviceByBusID(busID string) (*USBDevice, bool) {
 	return nil, false
 }
 
-func (s *USBDeviceStore) sendChange() {
-	s.mu.RLock()
+func (s *USBDeviceStore) unlockedSendChange() {
 	ch := s.changesCh
-	s.mu.RUnlock()
 	if ch != nil {
 		s.log.Debug("Notifying USB device store")
 		select {
@@ -120,6 +118,8 @@ func (s *USBDeviceStore) sendChange() {
 // AddDevice adds or updates a device and notifies if changed.
 func (s *USBDeviceStore) AddDevice(path string, device *USBDevice) bool {
 	s.mu.Lock()
+	defer s.mu.Unlock()
+
 	oldDevice, exists := s.devices[path]
 	needNotify := false
 	if !exists || !device.Equal(oldDevice) {
@@ -133,10 +133,9 @@ func (s *USBDeviceStore) AddDevice(path string, device *USBDevice) bool {
 		)
 		needNotify = true
 	}
-	s.mu.Unlock()
 
 	if needNotify {
-		s.sendChange()
+		s.unlockedSendChange()
 	}
 	return needNotify
 }
@@ -144,6 +143,8 @@ func (s *USBDeviceStore) AddDevice(path string, device *USBDevice) bool {
 // RemoveDevice removes a device and notifies if it existed.
func (s *USBDeviceStore) RemoveDevice(path string) bool { s.mu.Lock() + defer s.mu.Unlock() + device, exists := s.devices[path] needNotify := false if exists { @@ -154,10 +155,9 @@ func (s *USBDeviceStore) RemoveDevice(path string) bool { delete(s.devices, path) needNotify = true } - s.mu.Unlock() if needNotify { - s.sendChange() + s.unlockedSendChange() } return needNotify } @@ -165,6 +165,8 @@ func (s *USBDeviceStore) RemoveDevice(path string) bool { // Resync synchronizes the store with discovered devices and notifies if changed. func (s *USBDeviceStore) Resync(devices map[string]*USBDevice) bool { s.mu.Lock() + defer s.mu.Unlock() + changed := false // Check for removed devices @@ -193,10 +195,9 @@ func (s *USBDeviceStore) Resync(devices map[string]*USBDevice) bool { changed = true } } - s.mu.Unlock() if changed { - s.sendChange() + s.unlockedSendChange() } return changed diff --git a/images/virtualization-dra/pkg/libusb/discovery.go b/images/virtualization-dra/pkg/libusb/discovery.go index d422c8aba3..dd0277b9c8 100644 --- a/images/virtualization-dra/pkg/libusb/discovery.go +++ b/images/virtualization-dra/pkg/libusb/discovery.go @@ -33,13 +33,15 @@ func DiscoverPluggedUSBDevices() (map[string]*USBDevice, error) { } for _, entry := range entries { - if !entry.IsDir() { + path := filepath.Join(pathToUSBDevices, entry.Name()) + + if entry.Type()&os.ModeSymlink == 0 { + slog.Debug("Skipping non-symlink entry", slog.String("path", path), slog.String("type", entry.Type().String())) continue } - path := filepath.Join(pathToUSBDevices, entry.Name()) - if !isUsbPath(path) { + slog.Debug("Skipping non-usb path", slog.String("path", path)) continue } @@ -54,6 +56,7 @@ func DiscoverPluggedUSBDevices() (map[string]*USBDevice, error) { continue } + slog.Debug("Discovered usb device", slog.String("path", path)) devices[path] = &device } diff --git a/images/virtualization-dra/pkg/libusb/udev_monitor.go b/images/virtualization-dra/pkg/libusb/udev_monitor.go index 
8ee2047424..0aee2094f4 100644 --- a/images/virtualization-dra/pkg/libusb/udev_monitor.go +++ b/images/virtualization-dra/pkg/libusb/udev_monitor.go @@ -18,6 +18,7 @@ package libusb import ( "context" + "fmt" "log/slog" "os" "path/filepath" @@ -89,13 +90,12 @@ type udevMonitor interface { // NewUdevMonitor creates a new USB monitor that uses the udev package func NewUdevMonitor(ctx context.Context, opts ...UdevMonitorOption) (Monitor, error) { + log := slog.With(slog.String("component", "udev-usb-monitor")) devices, err := DiscoverPluggedUSBDevices() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to discover USB devices during resync: %w", err) } - log := slog.With(slog.String("component", "udev-usb-monitor")) - m := &UdevMonitor{ store: NewUSBDeviceStore(devices, log), log: log, @@ -231,7 +231,7 @@ func (m *UdevMonitor) handleDeviceUpdate(path string) { return } - slog.Debug("Load usb device", slog.String("path", path)) + m.log.Debug("Load usb device", slog.String("path", path)) device, err := LoadUSBDevice(path) if err != nil { @@ -239,24 +239,56 @@ func (m *UdevMonitor) handleDeviceUpdate(path string) { return } - slog.Debug("Validate usb device", slog.String("path", path)) + m.log.Debug("Validate usb device", slog.String("path", path)) if err := device.Validate(); err != nil { m.log.Debug("device validation failed", slog.String("path", path), slog.String("error", err.Error())) return } - slog.Debug("Add usb device", slog.String("path", path)) + m.log.Debug("Add usb device", slog.String("path", path)) m.store.AddDevice(path, &device) } func (m *UdevMonitor) handleDeviceRemove(path string) { - slog.Debug("Remove usb device", slog.String("path", path)) - m.store.RemoveDevice(path) + // Small delay for sysfs to be fully populated + time.Sleep(50 * time.Millisecond) + + // Check if path exists + if _, err := os.Stat(path); os.IsNotExist(err) { + m.log.Debug("Remove usb device", slog.String("path", path)) + m.store.RemoveDevice(path) + return + } + 
+ if !isUsbPath(path) { + return + } + + device, err := LoadUSBDevice(path) + if err != nil { + m.log.Debug("Remove usb device", slog.String("path", path)) + m.log.Debug("failed to load device", slog.String("path", path), slog.String("error", err.Error())) + m.store.RemoveDevice(path) + return + } + + if err := device.Validate(); err != nil { + m.log.Debug("Remove usb device", slog.String("path", path)) + m.log.Debug("device validation failed", slog.String("path", path), slog.String("error", err.Error())) + m.store.RemoveDevice(path) + return + } + + m.log.Debug("Add usb device", slog.String("path", path)) + + // update usb device + m.store.AddDevice(path, &device) } func (m *UdevMonitor) resync() { + m.log.Info("Resync usb devices") devices, err := DiscoverPluggedUSBDevices() if err != nil { m.log.Error("failed to discover USB devices during resync", slog.String("error", err.Error())) diff --git a/images/virtualization-dra/pkg/modprobe/modprobe.go b/images/virtualization-dra/pkg/modprobe/modprobe.go new file mode 100644 index 0000000000..bcbf563817 --- /dev/null +++ b/images/virtualization-dra/pkg/modprobe/modprobe.go @@ -0,0 +1,134 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package modprobe + +import ( + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/klauspost/compress/zstd" + "golang.org/x/sys/unix" +) + +func LoadModules(modules ...string) error { + for _, module := range modules { + if err := loadModule(module); err != nil { + return fmt.Errorf("load module %s: %w", module, err) + } + } + + return nil +} + +func loadModule(path string) error { + if strings.HasSuffix(path, ".zst") { + uncompressedPath, err := uncompressModuleToTmp(path) + if err != nil { + return fmt.Errorf("uncompress module %s: %w", path, err) + } + defer func() { + if err := os.Remove(uncompressedPath); err != nil { + slog.Error("remove uncompressed module", "path", uncompressedPath, "err", err) + } + }() + path = uncompressedPath + } + + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("open %s: %w", path, err) + } + defer f.Close() + + if err = unix.FinitModule(int(f.Fd()), "", 0); err != nil { + if errors.Is(err, unix.EEXIST) { + slog.Info("Module already loaded", slog.String("path", path)) + return nil + } + return fmt.Errorf("finit_module %s: %w", path, err) + } + + slog.Info("Module loaded", slog.String("path", path)) + + return nil +} + +func uncompressModuleToTmp(path string) (string, error) { + pattern := filepath.Base(path) + "-*" + uncompress, err := os.CreateTemp("", pattern) + if err != nil { + return "", err + } + defer uncompress.Close() + + in, err := os.Open(path) + if err != nil { + return "", err + } + defer in.Close() + + decoder, err := zstd.NewReader(in) + if err != nil { + return "", err + } + defer decoder.Close() + + if _, err := io.Copy(uncompress, decoder); err != nil { + return "", err + } + + return uncompress.Name(), nil +} + +func KernelRelease() (string, error) { + var uts unix.Utsname + if err := unix.Uname(&uts); err != nil { + return "", fmt.Errorf("uname: %w", err) + } + return unix.ByteSliceToString(uts.Release[:]), nil +} + +func 
KernelSupportsZst(release string) (bool, error) { + parts := strings.Split(release, ".") + if len(parts) < 2 { + return false, fmt.Errorf("invalid release %s", release) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return false, fmt.Errorf("invalid major %s: %w", parts[0], err) + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return false, fmt.Errorf("invalid minor %s: %w", parts[1], err) + } + + // ZST is supported since 5.16 + if major > 5 { + return true, nil + } + if major == 5 && minor >= 16 { + return true, nil + } + return false, nil +} diff --git a/images/virtualization-dra/pkg/patch/patch.go b/images/virtualization-dra/pkg/patch/patch.go new file mode 100644 index 0000000000..f544557b2e --- /dev/null +++ b/images/virtualization-dra/pkg/patch/patch.go @@ -0,0 +1,123 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package patch + +import ( + "encoding/json" + "fmt" + "slices" + "strconv" + "strings" +) + +const ( + PatchReplaceOp = "replace" + PatchAddOp = "add" + PatchRemoveOp = "remove" + PatchTestOp = "test" +) + +type JSONPatch struct { + operations []JSONPatchOperation +} + +type JSONPatchOperation struct { + Op string `json:"op"` + Path string `json:"path"` + Value interface{} `json:"value,omitempty"` +} + +func NewJSONPatch(patches ...JSONPatchOperation) *JSONPatch { + return &JSONPatch{ + operations: patches, + } +} + +func NewJSONPatchOperation(op, path string, value interface{}) JSONPatchOperation { + return JSONPatchOperation{ + Op: op, + Path: path, + Value: value, + } +} + +func WithAdd(path string, value interface{}) JSONPatchOperation { + return NewJSONPatchOperation(PatchAddOp, path, value) +} + +func WithRemove(path string) JSONPatchOperation { + return NewJSONPatchOperation(PatchRemoveOp, path, nil) +} + +func WithReplace(path string, value interface{}) JSONPatchOperation { + return NewJSONPatchOperation(PatchReplaceOp, path, value) +} + +func WithTest(path string, value interface{}) JSONPatchOperation { + return NewJSONPatchOperation(PatchTestOp, path, value) +} + +func (jp *JSONPatch) Operations() []JSONPatchOperation { + return jp.operations +} + +func (jp *JSONPatch) Append(patches ...JSONPatchOperation) { + jp.operations = append(jp.operations, patches...) 
+} + +func (jp *JSONPatch) Delete(op, path string) { + jp.operations = slices.DeleteFunc(jp.operations, func(o JSONPatchOperation) bool { + return o.Op == op && o.Path == path + }) +} + +func (jp *JSONPatch) Len() int { + return len(jp.operations) +} + +func (jp *JSONPatch) String() (string, error) { + bytes, err := jp.Bytes() + if err != nil { + return "", err + } + return string(bytes), nil +} + +func (jp *JSONPatch) Bytes() ([]byte, error) { + if jp.Len() == 0 { + return nil, fmt.Errorf("list of patches is empty") + } + return json.Marshal(jp.operations) +} + +func EscapeJSONPointer(path string) string { + path = strings.ReplaceAll(path, "~", "~0") + path = strings.ReplaceAll(path, "/", "~1") + return path +} + +type AsJsonString struct { + Data interface{} +} + +func (a AsJsonString) MarshalJSON() ([]byte, error) { + b, err := json.Marshal(a.Data) + if err != nil { + return nil, err + } + return []byte(strconv.Quote(string(b))), nil +} diff --git a/images/virtualization-dra/pkg/udev/conn.go b/images/virtualization-dra/pkg/udev/conn.go index 76c289de1c..ce7e20a74d 100644 --- a/images/virtualization-dra/pkg/udev/conn.go +++ b/images/virtualization-dra/pkg/udev/conn.go @@ -18,6 +18,7 @@ package udev import ( "fmt" + "log/slog" "os" "runtime" "syscall" @@ -113,14 +114,22 @@ func (c *Conn) connectInNetNS(mode Mode) error { if err != nil { return fmt.Errorf("failed to open current netns: %w", err) } - defer unix.Close(currentNS) + defer func(fd int) { + if err = unix.Close(fd); err != nil { + slog.Error("failed to close current netns", slog.String("error", err.Error())) + } + }(currentNS) // Open target network namespace targetNS, err := unix.Open(c.netNS, unix.O_RDONLY|unix.O_CLOEXEC, 0) if err != nil { return fmt.Errorf("failed to open target netns %s: %w", c.netNS, err) } - defer unix.Close(targetNS) + defer func(fd int) { + if err = unix.Close(fd); err != nil { + slog.Error("failed to close target netns", slog.String("error", err.Error())) + } + }(targetNS) // 
Switch to target network namespace if err := unix.Setns(targetNS, unix.CLONE_NEWNET); err != nil { diff --git a/images/virtualization-dra/pkg/usbip/attacher.go b/images/virtualization-dra/pkg/usbip/attacher.go new file mode 100644 index 0000000000..6816e4b01e --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/attacher.go @@ -0,0 +1,316 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package usbip + +import ( + "fmt" + "log/slog" + "net" + "os" + "strconv" + "sync" + "syscall" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" + "github.com/deckhouse/virtualization-dra/pkg/usbip/protocol" +) + +func NewUSBAttacher() USBAttacher { + return &usbAttacher{} +} + +type usbAttacher struct { + mu sync.Mutex +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_attach.c#L174 +func (a *usbAttacher) Attach(host, busID string, port int) (int, error) { + a.mu.Lock() + defer a.mu.Unlock() + + conn, err := a.usbipNetTCPConnect(host, strconv.Itoa(port)) + if err != nil { + return -1, fmt.Errorf("failed to connect to usbipd: %w", err) + } + + rhport, err := a.queryImportDevice(conn, busID) + if err != nil { + return -1, fmt.Errorf("failed to query import device: %w", err) + } + + err = a.recordConnection(host, strconv.Itoa(port), busID, rhport) + if err != nil { + return -1, fmt.Errorf("failed to record connection: %w", err) + } + + return rhport, nil +} + +// 
https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_detach.c#L32 +func (a *usbAttacher) Detach(rhport int) error { + a.mu.Lock() + defer a.mu.Unlock() + + driver, err := newVhciDriver() + if err != nil { + return fmt.Errorf("failed to get vhci driver: %w", err) + } + + found := false + for i := 0; i < driver.nports; i++ { + idev := &driver.idevs[i] + + if idev.port == rhport { + found = true + vstatus := protocol.DeviceStatus(idev.status) + if vstatus == protocol.VDeviceStatusNull { + slog.Info("Port is already detached", slog.Int("rhport", rhport)) + return fmt.Errorf("port is already detached") + } + + break + } + } + + if !found { + slog.Error("Invalid port > maxports", slog.Int("rhport", rhport), slog.Int("maxports", driver.nports)) + return fmt.Errorf("rhport %d not found", rhport) + } + + path := vhciStatePortPath(rhport) + + if err = os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove vhci state port file: %w", err) + } + + if err = os.RemoveAll(vhciStatePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove vhci state path: %w", err) + } + + if err = writeSysfsAttr(vhciHcdDetach, detachAttr{port: rhport}); err != nil { + return fmt.Errorf("failed to write detach attribute: %w", err) + } + + slog.Info("Port detached", slog.Int("rhport", rhport)) + return nil +} + +func (a *usbAttacher) GetAttachInfo() ([]AttachInfo, error) { + driver, err := newVhciDriver() + if err != nil { + return nil, fmt.Errorf("failed to get vhci driver: %w", err) + } + + var usedInfos []AttachInfo + + for i := 0; i < driver.nports; i++ { + idev := &driver.idevs[i] + + vstatus := protocol.DeviceStatus(idev.status) + if vstatus == protocol.VDeviceStatusUsed { + usedInfos = append(usedInfos, AttachInfo{ + Port: idev.port, + Busnum: idev.busnum, + Devnum: idev.devnum, + LocalBusID: idev.localBusID, + }) + } + } + + return usedInfos, nil +} + +// 
https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_network.c#L261 +func (a *usbAttacher) usbipNetTCPConnect(host, port string) (*net.TCPConn, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port)) + if err != nil { + return nil, fmt.Errorf("resolve TCP address: %w", err) + } + + conn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + return nil, fmt.Errorf("dial TCP: %w", err) + } + + if err := conn.SetNoDelay(true); err != nil { + if conErr := conn.Close(); conErr != nil { + slog.Error("failed to close connection", slog.String("error", conErr.Error())) + } + return nil, fmt.Errorf("set TCP_NODELAY: %w", err) + } + + if err := conn.SetKeepAlive(true); err != nil { + if conErr := conn.Close(); conErr != nil { + slog.Error("failed to close connection", slog.String("error", conErr.Error())) + } + return nil, fmt.Errorf("set keepalive: %w", err) + } + + return conn, nil +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_attach.c#L120 +func (a *usbAttacher) queryImportDevice(conn *net.TCPConn, busID string) (int, error) { + opCommon := protocol.NewOpCommon(protocol.OpReqImport, protocol.OpStatusOk) + importReq := protocol.NewImportRequest(busID) + + if err := opCommon.Encode(conn); err != nil { + return -1, fmt.Errorf("failed to encode OpCommon: %w", err) + } + + if err := importReq.Encode(conn); err != nil { + return -1, fmt.Errorf("failed to encode ImportRequest: %w", err) + } + + importReply := &protocol.ImportReply{} + if err := importReply.Decode(conn); err != nil { + return -1, fmt.Errorf("failed to decode ImportReply: %w", err) + } + + if importReply.Version != protocol.Version { + return -1, fmt.Errorf("unsupported USBIP version: %d", importReply.Version) + } + + if importReply.Status != protocol.OpStatusOk { + return -1, fmt.Errorf("reply failed: %s", importReply.Status.String()) + } + + if 
importReply.GetBusID() != busID { + return -1, fmt.Errorf("busID mismatch: %s != %s", importReply.GetBusID(), busID) + } + + return a.importDevice(conn, importReply.USBDevice) +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_attach.c#L81 +func (a *usbAttacher) importDevice(conn *net.TCPConn, usbDevice protocol.USBDevice) (int, error) { + port, err := a.getFreePort(usbDevice.Speed) + if err != nil { + return -1, fmt.Errorf("failed to get free port: %w", err) + } + + sockFd, err := a.getSockFd(conn) + if err != nil { + return -1, fmt.Errorf("failed to get socket fd: %w", err) + } + + devID := getDevId(usbDevice.Busnum, usbDevice.Devnum) + + attr := attachAttr{ + port: port, + sockFd: sockFd, + devId: devID, + speed: usbDevice.Speed, + } + + err = writeSysfsAttr(vhciHcdAttach, attr) + if err != nil { + return -1, fmt.Errorf("failed to write attach attribute: %w", err) + } + + return port, nil +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/libsrc/vhci_driver.c#L334 +func (a *usbAttacher) getFreePort(speed uint32) (int, error) { + driver, err := newVhciDriver() + if err != nil { + return -1, err + } + + deviceSpeed := libusb.USBDeviceSpeed(speed) + + for i := 0; i < driver.nports; i++ { + switch deviceSpeed { + case libusb.USBDeviceSpeedSuper: + if driver.idevs[i].hub != hubSpeedSuper { + continue + } + default: + if driver.idevs[i].hub != hubSpeedHigh { + continue + } + } + vstatus := protocol.DeviceStatus(driver.idevs[i].status) + if vstatus == protocol.VDeviceStatusNull { + return driver.idevs[i].port, nil + } + } + + return -1, nil +} + +func (a *usbAttacher) getSockFd(conn *net.TCPConn) (int, error) { + file, err := conn.File() + if err != nil { + return -1, err + } + defer file.Close() + + fd := int(file.Fd()) + + newFd, err := syscall.Dup(fd) + if err != nil { + return -1, err + } + + return newFd, nil +} + +// 
https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/src/usbip_attach.c#L39 +func (a *usbAttacher) recordConnection(host, port, busID string, rhport int) error { + err := os.MkdirAll(vhciStatePath, 0o700) + if err != nil { + return fmt.Errorf("failed to create vhci state path: %w", err) + } + + path := vhciStatePortPath(rhport) + + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o700) + if err != nil { + return fmt.Errorf("failed to open vhci state port file: %w", err) + } + defer file.Close() + + value := fmt.Sprintf("%s %s %s", host, port, busID) + + _, err = file.WriteString(value) + if err != nil { + return fmt.Errorf("failed to write vhci state port file: %w", err) + } + + return nil +} + +type attachAttr struct { + port int + sockFd int + devId int + speed uint32 +} + +func (a attachAttr) Complete() string { + return fmt.Sprintf("%d %d %d %d", a.port, a.sockFd, a.devId, a.speed) +} + +type detachAttr struct { + port int +} + +func (a detachAttr) Complete() string { + return fmt.Sprintf("%d", a.port) +} diff --git a/images/virtualization-dra/pkg/usbip/binder.go b/images/virtualization-dra/pkg/usbip/binder.go new file mode 100644 index 0000000000..b2c68e3d15 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/binder.go @@ -0,0 +1,254 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package usbip + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" +) + +func NewUSBBinder() USBBinder { + return &usbBinder{} +} + +type usbBinder struct { + mu sync.Mutex +} + +// Bind binds the USB device to the USBIP server. +// https://github.com/torvalds/linux/blob/40fbbd64bba6c6e7a72885d2f59b6a3be9991eeb/tools/usb/usbip/src/usbip_bind.c#L130 +func (b *usbBinder) Bind(busID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + devInfo, err := b.getUSBDeviceInfo(busID) + if err != nil { + return fmt.Errorf("device with bus ID %s does not exist: %w", busID, err) + } + + if strings.Contains(devInfo.DevPath, "vhci_hcd") { + return fmt.Errorf("bind loop detected: device %s is already attached to vhci_hcd", busID) + } + + err = b.unbindOther(devInfo) + if err != nil { + return fmt.Errorf("failed to unbind other devices: %w", err) + } + + if err = b.modifyMatchBusID(busID, true); err != nil { + return err + } + + if err = b.bindUsbip(busID); err != nil { + return fmt.Errorf("failed to bind usb device: %w: %w", err, b.modifyMatchBusID(busID, false)) + } + + return nil +} + +// Unbind unbinds the USB device from the USBIP server. 
+// https://github.com/torvalds/linux/blob/40fbbd64bba6c6e7a72885d2f59b6a3be9991eeb/tools/usb/usbip/src/usbip_unbind.c#L30 +func (b *usbBinder) Unbind(busID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + devInfo, err := b.getUSBDeviceInfo(busID) + if err != nil { + return fmt.Errorf("device with bus ID %s does not exist: %w", busID, err) + } + + if !b.isBound(devInfo) { + return fmt.Errorf("device %s is not bound to %s driver", devInfo.BusID, usbipHostDriverName) + } + + if err = b.unbindUsbip(busID); err != nil { + return fmt.Errorf("failed to unbind usb device %s: %w", busID, err) + } + + // notify driver of unbind + if err = b.modifyMatchBusID(busID, false); err != nil { + return fmt.Errorf("failed to modify match bus ID %s: %w", busID, err) + } + + // Trigger new probing + if err = b.rebindUsbip(busID); err != nil { + return fmt.Errorf("failed to rebind usb device %s: %w", busID, err) + } + + return nil +} + +func (b *usbBinder) IsBound(busID string) (bool, error) { + devInfo, err := b.getUSBDeviceInfo(busID) + if err != nil { + return false, fmt.Errorf("device with bus ID %s does not exist: %w", busID, err) + } + return b.isBound(devInfo), nil +} + +func (b *usbBinder) GetBindInfo() ([]BindInfo, error) { + usbDevices, err := libusb.DiscoverPluggedUSBDevices() + if err != nil { + return nil, fmt.Errorf("failed to discover USB devices: %w", err) + } + + var infos []BindInfo + + for _, device := range usbDevices { + devInfo := usbDeviceInfo{ + BusID: device.BusID, + Driver: device.Driver, + DevPath: device.DevicePath, + IsHub: device.IsHub, + } + + infos = append(infos, BindInfo{ + DevicePath: device.DevicePath, + BusID: device.BusID, + Busnum: int(device.Bus), + Devnum: int(device.DeviceNumber), + Bound: b.isBound(&devInfo), + }) + } + + return infos, err +} + +type usbDeviceInfo struct { + BusID string + Driver string + DevPath string + IsHub bool +} + +func (b *usbBinder) getUSBDeviceInfo(busID string) (*usbDeviceInfo, error) { + path := 
getUSBDevicePath(busID) + + if _, err := os.Stat(path); err != nil { + return nil, err + } + + info := &usbDeviceInfo{ + BusID: busID, + } + + bDevClassPath := filepath.Join(path, "bDeviceClass") + data, err := os.ReadFile(bDevClassPath) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", bDevClassPath, err) + } + info.IsHub = strings.TrimSpace(string(data)) == "09" // 09 = USB Hub class + + ueventPath := filepath.Join(path, "uevent") + ueventFile, err := os.Open(ueventPath) + if err != nil { + return nil, fmt.Errorf("unable to open the file %s: %w", ueventPath, err) + } + defer ueventFile.Close() + scanner := bufio.NewScanner(ueventFile) + + count := 0 + for scanner.Scan() { + line := scanner.Text() + values := strings.Split(line, "=") + if len(values) != 2 { + continue + } + switch values[0] { + case "DEVNAME": + info.DevPath = filepath.Join("/dev", values[1]) + count++ + case "DRIVER": + info.Driver = values[1] + count++ + } + if count == 2 { + break + } + } + + return info, nil +} + +func (b *usbBinder) isBound(devInfo *usbDeviceInfo) bool { + return devInfo.Driver == usbipHostDriverName +} + +func (b *usbBinder) unbindOther(devInfo *usbDeviceInfo) error { + if devInfo.IsHub { + return fmt.Errorf("skip unbinding of hub %s", devInfo.BusID) + } + + if devInfo.Driver == "" { + // no driver bound to the device + return nil + } + + if b.isBound(devInfo) { + return fmt.Errorf("device %s is already bound to %s", devInfo.BusID, usbipHostDriverName) + } + + unbindPath := unbindAttrPath(devInfo.Driver) + + if err := writeSysfsAttr(unbindPath, busIDAttr{busID: devInfo.BusID}); err != nil { + return fmt.Errorf("error unbinding device %s from driver %s: %w", devInfo.BusID, devInfo.Driver, err) + } + + return nil +} + +func (b *usbBinder) bindUsbip(busID string) error { + return writeSysfsAttr(bindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) unbindUsbip(busID string) error { + return 
writeSysfsAttr(unbindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) rebindUsbip(busID string) error { + return writeSysfsAttr(rebindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) modifyMatchBusID(busID string, add bool) error { + return writeSysfsAttr(matchBusIDAttrPath(usbipHostDriverName), modifyMatchBusIDAttr{busID: busID, add: add}) +} + +type modifyMatchBusIDAttr struct { + busID string + add bool +} + +func (a modifyMatchBusIDAttr) Complete() string { + if a.add { + return fmt.Sprintf("add %s", a.busID) + } + return fmt.Sprintf("del %s", a.busID) +} + +type busIDAttr struct { + busID string +} + +func (a busIDAttr) Complete() string { + return a.busID +} diff --git a/images/virtualization-dra/pkg/usbip/exporter.go b/images/virtualization-dra/pkg/usbip/exporter.go new file mode 100644 index 0000000000..9956a0efff --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/exporter.go @@ -0,0 +1,120 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package usbip + +import ( + "fmt" + "log/slog" + "net" + "strconv" + + "github.com/deckhouse/virtualization-dra/pkg/usbip/protocol" +) + +func NewUSBExporter() USBExporter { + return &usbExporter{} +} + +type usbExporter struct{} + +func (e *usbExporter) Export(host, busID string, port int) error { + conn, err := e.usbipNetTCPConnect(host, port) + if err != nil { + return fmt.Errorf("failed to connect to usbipd: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + slog.Error("failed to close connection", slog.String("error", err.Error())) + } + }() + + opCommon := protocol.NewOpCommon(protocol.OpReqExport, protocol.OpStatusOk) + if err = opCommon.Encode(conn); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + + exportReq := protocol.NewExportRequest(busID) + if err = exportReq.Encode(conn); err != nil { + return fmt.Errorf("failed to encode ExportRequest: %w", err) + } + + exportReply := &protocol.ExportReply{} + if err = exportReply.Decode(conn); err != nil { + return fmt.Errorf("failed to decode ExportReply: %w", err) + } + + if exportReply.Version != protocol.Version { + return fmt.Errorf("unsupported USBIP version: %d", exportReply.Version) + } + + if exportReply.Status != protocol.OpStatusOk { + return fmt.Errorf("reply failed: %s", exportReply.Status.String()) + } + + return nil +} + +func (e *usbExporter) Unexport(host, busID string, port int) error { + conn, err := e.usbipNetTCPConnect(host, port) + if err != nil { + return fmt.Errorf("failed to connect to usbipd: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + slog.Error("failed to close connection", slog.String("error", err.Error())) + } + }() + + opCommon := protocol.NewOpCommon(protocol.OpReqUnexport, protocol.OpStatusOk) + if err = opCommon.Encode(conn); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + + unExportReq := protocol.NewUnExportRequest(busID) + if err = unExportReq.Encode(conn); err != 
nil { + return fmt.Errorf("failed to encode UnExportRequest: %w", err) + } + + unExportReply := &protocol.UnExportReply{} + if err = unExportReply.Decode(conn); err != nil { + return fmt.Errorf("failed to decode UnExportReply: %w", err) + } + + if unExportReply.Version != protocol.Version { + return fmt.Errorf("unsupported USBIP version: %d", unExportReply.Version) + } + + if unExportReply.Status != protocol.OpStatusOk { + return fmt.Errorf("reply failed: %s", unExportReply.Status.String()) + } + + return nil +} + +func (e *usbExporter) usbipNetTCPConnect(host string, port int) (*net.TCPConn, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return nil, fmt.Errorf("resolve TCP address: %w", err) + } + + conn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + return nil, fmt.Errorf("dial TCP: %w", err) + } + + return conn, nil +} diff --git a/images/virtualization-dra/pkg/usbip/interfaces.go b/images/virtualization-dra/pkg/usbip/interfaces.go new file mode 100644 index 0000000000..81f23a8f09 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/interfaces.go @@ -0,0 +1,110 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package usbip + +type Interface interface { + ServerInterface + ClientInterface +} + +type ServerInterface interface { + USBBinder +} + +type ClientInterface interface { + USBAttacher + USBExporter +} + +type USBBinder interface { + Bind(busID string) error + Unbind(busID string) error + IsBound(busID string) (bool, error) + BindInfoGetter +} + +type BindInfoGetter interface { + GetBindInfo() ([]BindInfo, error) +} + +type BindInfo struct { + DevicePath string + BusID string + Busnum int + Devnum int + Bound bool +} + +type USBAttacher interface { + Attach(host, busID string, port int) (int, error) + Detach(rhport int) error + AttachInfoGetter +} + +type AttachInfoGetter interface { + GetAttachInfo() ([]AttachInfo, error) +} + +type AttachInfo struct { + Port, Busnum, Devnum int + LocalBusID string +} + +type USBExporter interface { + Export(host, busID string, port int) error + Unexport(host, busID string, port int) error +} + +type serverImpl struct { + USBBinder +} + +func NewServer(binder USBBinder) ServerInterface { + return &serverImpl{USBBinder: binder} +} + +type clientImpl struct { + USBAttacher + USBExporter +} + +func NewClient(attacher USBAttacher, exporter USBExporter) ClientInterface { + return &clientImpl{USBAttacher: attacher, USBExporter: exporter} +} + +type interfaceImpl struct { + ServerInterface + ClientInterface +} + +func NewInterface(server ServerInterface, client ClientInterface) Interface { + return &interfaceImpl{ + ServerInterface: server, + ClientInterface: client, + } +} + +func New() Interface { + binder := NewUSBBinder() + attacher := NewUSBAttacher() + exporter := NewUSBExporter() + + server := NewServer(binder) + client := NewClient(attacher, exporter) + + return NewInterface(server, client) +} diff --git a/images/virtualization-dra/pkg/usbip/protocol/common.go b/images/virtualization-dra/pkg/usbip/protocol/common.go new file mode 100644 index 0000000000..f3e6265275 --- /dev/null +++ 
b/images/virtualization-dra/pkg/usbip/protocol/common.go @@ -0,0 +1,165 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// https://github.com/torvalds/linux/blob/master/tools/usb/usbip/src/usbip_network.h +package protocol + +import ( + "encoding/binary" + "fmt" + "io" +) + +type USBVersion uint16 + +const ( + Version USBVersion = 0x0111 +) + +type Op uint16 + +// Common header for all the kinds of PDUs. +const ( + OpRequest Op = 0x80 << 8 + OpReply Op = 0x00 << 8 +) + +// Dummy Code +const ( + OpUnspec Op = 0x00 + OpReqUnspec Op = OpRequest | OpUnspec + OpRepUnspec Op = OpReply | OpUnspec +) + +// Retrieve USB device information. (still not used) +const ( + OpDevInfo Op = 0x02 + OpReqDevInfo Op = OpRequest | OpDevInfo + OpRepDevInfo Op = OpReply | OpDevInfo +) + +// Import a remote USB device. +const ( + OpImport Op = 0x03 + OpReqImport Op = OpRequest | OpImport + OpRepImport Op = OpReply | OpImport +) + +// Negotiate IPSec encryption key. (still not used) +const ( + OpCrypkey Op = 0x04 + OpReqCrypkey Op = OpRequest | OpCrypkey + OpRepCrypkey Op = OpReply | OpCrypkey +) + +// Retrieve the list of exported USB devices. +const ( + OpDevList Op = 0x05 + OpReqDevList Op = OpRequest | OpDevList + OpRepDevList Op = OpReply | OpDevList +) + +// Export a USB device to a remote host. +const ( + OpExport Op = 0x06 + OpReqExport Op = OpRequest | OpExport + OpRepExport Op = OpReply | OpExport +) + +// un-Export a USB device from a remote host. 
+const ( + OpUnexport Op = 0x07 + OpReqUnexport Op = OpRequest | OpUnexport + OpRepUnexport Op = OpReply | OpUnexport +) + +type OpStatus uint32 + +const ( + OpStatusOk OpStatus = 0x00 + OpStatusNA OpStatus = 0x01 + OpStatusDevBusy OpStatus = 0x02 + OpStatusDevErr OpStatus = 0x03 + OpStatusNoDev OpStatus = 0x04 + OpStatusError OpStatus = 0x05 +) + +func (o OpStatus) String() string { + switch o { + case OpStatusOk: + return "OK" + case OpStatusNA: + return "NA" + case OpStatusDevBusy: + return "DevBusy" + case OpStatusDevErr: + return "DevErr" + case OpStatusNoDev: + return "NoDev" + case OpStatusError: + return "Error" + default: + return "Unknown" + } +} + +type DeviceStatus uint32 + +const ( + DeviceStatusAvailable DeviceStatus = iota + 0x01 + DeviceStatusUsed + DeviceStatusError + VDeviceStatusNull + VDeviceStatusNotAssigned + VDeviceStatusUsed + VDeviceStatusError +) + +func NewOpCommon(code Op, status OpStatus) *OpCommon { + return &OpCommon{ + Version: Version, + Code: code, + Status: status, + } +} + +type OpCommon struct { + Version USBVersion + Code Op + Status OpStatus +} + +func (op *OpCommon) Decode(r io.Reader) error { + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read OpCommon: %w", err) + } + + op.Version = USBVersion(binary.BigEndian.Uint16(buf[0:2])) + op.Code = Op(binary.BigEndian.Uint16(buf[2:4])) + op.Status = OpStatus(binary.BigEndian.Uint32(buf[4:8])) + return nil +} + +func (op *OpCommon) Encode(w io.Writer) error { + buf := make([]byte, 8) + binary.BigEndian.PutUint16(buf[0:2], uint16(op.Version)) + binary.BigEndian.PutUint16(buf[2:4], uint16(op.Code)) + binary.BigEndian.PutUint32(buf[4:8], uint32(op.Status)) + _, err := w.Write(buf) + return err +} diff --git a/images/virtualization-dra/pkg/usbip/protocol/convert.go b/images/virtualization-dra/pkg/usbip/protocol/convert.go new file mode 100644 index 0000000000..3594f8f0ce --- /dev/null +++ 
// ToDevicePath converts path into the fixed 256-byte, NUL-terminated form
// used for device paths on the wire. Paths longer than 255 bytes are
// truncated.
func ToDevicePath(path string) [256]byte {
	var result [256]byte
	writeCString(result[:], path)
	return result
}

// ToBusID converts busID into the fixed 32-byte, NUL-terminated form used
// for bus IDs on the wire. IDs longer than 31 bytes are truncated.
func ToBusID(busID string) [32]byte {
	var result [32]byte
	writeCString(result[:], busID)
	return result
}

// fromCString returns the bytes of buf up to (not including) the first NUL
// byte as a string. When buf contains no NUL the whole buffer is returned.
func fromCString(buf []byte) string {
	if end := bytes.IndexByte(buf, 0); end != -1 {
		buf = buf[:end]
	}
	return string(buf)
}

// writeCString zeroes dst and copies s into it, always leaving at least one
// trailing NUL byte; s is truncated to len(dst)-1 bytes when it does not fit.
// A zero-length dst is left as is: there is no room even for the terminator.
func writeCString(dst []byte, s string) {
	for i := range dst {
		dst[i] = 0
	}

	// Without this guard the truncation below would compute n = -1 and
	// panic when slicing dst[:n].
	if len(dst) == 0 {
		return
	}

	n := len(s)
	if n >= len(dst) {
		n = len(dst) - 1
	}

	copy(dst[:n], s)
}
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package protocol + +import ( + "encoding/binary" + "fmt" + "io" +) + +func NewDeviceList(status OpStatus, devices []USBDeviceInfo) *DeviceList { + return &DeviceList{ + OpCommon: OpCommon{ + Version: Version, + Code: OpReqDevList, + Status: status, + }, + Ndev: uint32(len(devices)), + Devices: devices, + } +} + +type DeviceList struct { + OpCommon + + Ndev uint32 + Devices []USBDeviceInfo +} + +func (d *DeviceList) Encode(w io.Writer) error { + if err := d.OpCommon.Encode(w); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf[0:4], d.Ndev) + + if _, err := w.Write(buf); err != nil { + return fmt.Errorf("failed to write Ndev to writer: %w", err) + } + + for _, dev := range d.Devices { + if err := dev.Encode(w); err != nil { + return fmt.Errorf("failed to encode USBDeviceInfo: %w", err) + } + } + + return nil +} + +const ( + sysfsPathMax = 256 + sysfsBusIdMax = 32 +) + +type USBDeviceInfo struct { + USBDevice + Interfaces []USBDeviceInterface +} + +func (d *USBDeviceInfo) Decode(r io.Reader) error { + if err := d.USBDevice.Decode(r); err != nil { + return fmt.Errorf("unable to decode USBDevice: %w", err) + } + + d.Interfaces = make([]USBDeviceInterface, d.BNumInterfaces) + for i := 0; i < int(d.BNumInterfaces); i++ { + if err := d.Interfaces[i].Decode(r); err != nil { + return fmt.Errorf("unable to decode USBDeviceInterface: %w", err) + } + } + + return nil +} + +func (d *USBDeviceInfo) Encode(w io.Writer) error { + if err := d.USBDevice.Encode(w); err != nil { + return 
fmt.Errorf("unable to encode USBDevice: %w", err) + } + + for _, iface := range d.Interfaces { + if err := iface.Encode(w); err != nil { + return fmt.Errorf("unable to encode USBDeviceInterface: %w", err) + } + } + + return nil +} + +type USBDevice struct { + Path [sysfsPathMax]byte + BusID [sysfsBusIdMax]byte + + Busnum uint32 + Devnum uint32 + Speed uint32 + + IDVendor uint16 + IDProduct uint16 + BcdDevice uint16 + + BDeviceClass uint8 + BDeviceSubClass uint8 + BDeviceProtocol uint8 + BConfigurationValue uint8 + BNumConfigurations uint8 + BNumInterfaces uint8 +} + +func (u *USBDevice) GetPath() string { + return fromCString(u.Path[:]) +} + +func (u *USBDevice) GetBusID() string { + return fromCString(u.BusID[:]) +} + +func (u *USBDevice) Decode(r io.Reader) error { + buf := make([]byte, sysfsPathMax+sysfsBusIdMax+12+6+6) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read USBDevice from reader: %w", err) + } + + copy(u.Path[:], buf[0:sysfsPathMax]) + copy(u.BusID[:], buf[sysfsPathMax:sysfsPathMax+sysfsBusIdMax]) + + pass := sysfsPathMax + sysfsBusIdMax + + u.Busnum = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + u.Devnum = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + u.Speed = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + + u.IDVendor = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + u.IDProduct = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + u.BcdDevice = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + + u.BDeviceClass = buf[pass] + pass += 1 + u.BDeviceSubClass = buf[pass] + pass += 1 + u.BDeviceProtocol = buf[pass] + pass += 1 + u.BConfigurationValue = buf[pass] + pass += 1 + u.BNumConfigurations = buf[pass] + pass += 1 + u.BNumInterfaces = buf[pass] + + return nil +} + +func (u *USBDevice) Encode(w io.Writer) error { + buf := make([]byte, sysfsPathMax+sysfsBusIdMax+12+6+6) + + copy(buf[0:sysfsPathMax], u.Path[:]) + 
copy(buf[sysfsPathMax:sysfsPathMax+sysfsBusIdMax], u.BusID[:]) + + pass := sysfsPathMax + sysfsBusIdMax + + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Busnum) + pass += 4 + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Devnum) + pass += 4 + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Speed) + pass += 4 + + binary.BigEndian.PutUint16(buf[pass:pass+2], u.IDVendor) + pass += 2 + binary.BigEndian.PutUint16(buf[pass:pass+2], u.IDProduct) + pass += 2 + binary.BigEndian.PutUint16(buf[pass:pass+2], u.BcdDevice) + pass += 2 + + buf[pass] = u.BDeviceClass + pass += 1 + buf[pass] = u.BDeviceSubClass + pass += 1 + buf[pass] = u.BDeviceProtocol + pass += 1 + buf[pass] = u.BConfigurationValue + pass += 1 + buf[pass] = u.BNumConfigurations + pass += 1 + buf[pass] = u.BNumInterfaces + + _, err := w.Write(buf) + if err != nil { + return fmt.Errorf("failed to write USBDevice to writer: %w", err) + } + return nil +} + +type USBDeviceInterface struct { + BInterfaceClass uint8 + BInterfaceSubClass uint8 + BInterfaceProtocol uint8 + padding uint8 +} + +func (u *USBDeviceInterface) Decode(r io.Reader) error { + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read USBDeviceInterface from reader: %w", err) + } + + u.BInterfaceClass = buf[0] + u.BInterfaceSubClass = buf[1] + u.BInterfaceProtocol = buf[2] + u.padding = buf[3] + + return nil +} + +func (u *USBDeviceInterface) Encode(w io.Writer) error { + buf := make([]byte, 4) + + buf[0] = u.BInterfaceClass + buf[1] = u.BInterfaceSubClass + buf[2] = u.BInterfaceProtocol + buf[3] = u.padding + + _, err := w.Write(buf) + if err != nil { + return fmt.Errorf("failed to write USBDeviceInterface to writer: %w", err) + } + return nil +} diff --git a/images/virtualization-dra/pkg/usbip/protocol/export.go b/images/virtualization-dra/pkg/usbip/protocol/export.go new file mode 100644 index 0000000000..bab69b1463 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/protocol/export.go @@ 
-0,0 +1,80 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package protocol + +import ( + "fmt" + "io" +) + +func NewExportRequest(busID string) *ExportRequest { + return &ExportRequest{ + busID: ToBusID(busID), + } +} + +type ExportRequest struct { + busID [sysfsBusIdMax]byte +} + +func (i *ExportRequest) BusID() string { + return fromCString(i.busID[:]) +} + +func (i *ExportRequest) Encode(w io.Writer) error { + _, err := w.Write(i.busID[:]) + return err +} + +func (i *ExportRequest) Decode(r io.Reader) error { + buf := make([]byte, sysfsBusIdMax) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read ExportRequest from reader: %w", err) + } + + copy(i.busID[:], buf) + return nil +} + +type ExportReply struct { + OpCommon +} + +func NewExportReply(status OpStatus) *ExportReply { + return &ExportReply{ + OpCommon: OpCommon{ + Version: Version, + Code: OpRepExport, + Status: status, + }, + } +} + +func (i *ExportReply) Encode(w io.Writer) error { + if err := i.OpCommon.Encode(w); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + return nil +} + +func (i *ExportReply) Decode(r io.Reader) error { + if err := i.OpCommon.Decode(r); err != nil { + return fmt.Errorf("failed to decode OpCommon: %w", err) + } + return nil +} diff --git a/images/virtualization-dra/pkg/usbip/protocol/import.go b/images/virtualization-dra/pkg/usbip/protocol/import.go new file mode 100644 index 0000000000..783e659936 --- 
/dev/null +++ b/images/virtualization-dra/pkg/usbip/protocol/import.go @@ -0,0 +1,88 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package protocol + +import ( + "fmt" + "io" +) + +func NewImportRequest(busID string) *ImportRequest { + return &ImportRequest{ + busID: ToBusID(busID), + } +} + +type ImportRequest struct { + busID [sysfsBusIdMax]byte +} + +func (i *ImportRequest) BusID() string { + return fromCString(i.busID[:]) +} + +func (i *ImportRequest) Encode(w io.Writer) error { + _, err := w.Write(i.busID[:]) + return err +} + +func (i *ImportRequest) Decode(r io.Reader) error { + buf := make([]byte, sysfsBusIdMax) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read ImportRequest from reader: %w", err) + } + + copy(i.busID[:], buf) + return nil +} + +type ImportReply struct { + OpCommon + USBDevice +} + +func NewImportReply(status OpStatus, device USBDevice) *ImportReply { + return &ImportReply{ + OpCommon: OpCommon{ + Version: Version, + Code: OpRepImport, + Status: status, + }, + USBDevice: device, + } +} + +func (i *ImportReply) Encode(w io.Writer) error { + if err := i.OpCommon.Encode(w); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + if err := i.USBDevice.Encode(w); err != nil { + return fmt.Errorf("failed to encode USBDevice: %w", err) + } + return nil +} + +func (i *ImportReply) Decode(r io.Reader) error { + if err := i.OpCommon.Decode(r); err != nil { + return 
fmt.Errorf("failed to decode OpCommon: %w", err) + } + if err := i.USBDevice.Decode(r); err != nil { + return fmt.Errorf("failed to decode USBDevice: %w", err) + } + return nil +} diff --git a/images/virtualization-dra/pkg/usbip/protocol/unexport.go b/images/virtualization-dra/pkg/usbip/protocol/unexport.go new file mode 100644 index 0000000000..40edbeb89b --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/protocol/unexport.go @@ -0,0 +1,80 @@ +/* +Copyright 2026 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package protocol + +import ( + "fmt" + "io" +) + +func NewUnExportRequest(busID string) *UnExportRequest { + return &UnExportRequest{ + busID: ToBusID(busID), + } +} + +type UnExportRequest struct { + busID [sysfsBusIdMax]byte +} + +func (i *UnExportRequest) BusID() string { + return fromCString(i.busID[:]) +} + +func (i *UnExportRequest) Encode(w io.Writer) error { + _, err := w.Write(i.busID[:]) + return err +} + +func (i *UnExportRequest) Decode(r io.Reader) error { + buf := make([]byte, sysfsBusIdMax) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read UnExportRequest from reader: %w", err) + } + + copy(i.busID[:], buf) + return nil +} + +type UnExportReply struct { + OpCommon +} + +func NewUnExportReply(status OpStatus) *UnExportReply { + return &UnExportReply{ + OpCommon: OpCommon{ + Version: Version, + Code: OpRepExport, + Status: status, + }, + } +} + +func (i *UnExportReply) Encode(w io.Writer) error { + if err := i.OpCommon.Encode(w); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + return nil +} + +func (i *UnExportReply) Decode(r io.Reader) error { + if err := i.OpCommon.Decode(r); err != nil { + return fmt.Errorf("failed to decode OpCommon: %w", err) + } + return nil +} diff --git a/images/virtualization-dra/pkg/usbip/sysfs.go b/images/virtualization-dra/pkg/usbip/sysfs.go new file mode 100644 index 0000000000..ed2fe24813 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/sysfs.go @@ -0,0 +1,79 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// sysfsAttr is a value that can be rendered into the exact string written to
// a sysfs attribute file.
type sysfsAttr interface {
	Complete() string
}

// writeSysfsAttr opens the existing file attrPath write-only and writes
// value.Complete() to it.
//
// The file is never created (no O_CREATE), so the permission argument is
// unused. The error from Close is returned when the write itself succeeded,
// so a failure reported at close time is no longer dropped silently.
func writeSysfsAttr(attrPath string, value sysfsAttr) (err error) {
	f, err := os.OpenFile(attrPath, os.O_WRONLY, 0o644)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error: a write failure is more informative than
		// the close failure that may follow it.
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	_, err = f.WriteString(value.Complete())
	return err
}

const (
	// Per-driver attribute files; %s is the driver name.
	bindAttrPathTmpl       = "/sys/bus/usb/drivers/%s/bind"
	unbindAttrPathTmpl     = "/sys/bus/usb/drivers/%s/unbind"
	rebindAttrPathTmpl     = "/sys/bus/usb/drivers/%s/rebind"
	matchBusIDAttrPathTmpl = "/sys/bus/usb/drivers/%s/match_busid"

	// Per-device directory; %s is the bus ID.
	usbDevicesTmpl = "/sys/bus/usb/devices/%s"

	// Per-device usbip attribute files; %s is the bus ID.
	usbipStatusPathTmpl = "/sys/bus/usb/devices/%s/usbip_status"
	usbipSockFdPathTmpl = "/sys/bus/usb/devices/%s/usbip_sockfd"

	usbipHostDriverName = "usbip-host"
)

// getUSBDevicePath returns the sysfs directory of the device with busID.
func getUSBDevicePath(busID string) string {
	return fmt.Sprintf(usbDevicesTmpl, busID)
}

// bindAttrPath returns the path of the driver's "bind" attribute.
func bindAttrPath(driver string) string {
	return fmt.Sprintf(bindAttrPathTmpl, driver)
}

// unbindAttrPath returns the path of the driver's "unbind" attribute.
func unbindAttrPath(driver string) string {
	return fmt.Sprintf(unbindAttrPathTmpl, driver)
}

// rebindAttrPath returns the path of the driver's "rebind" attribute.
func rebindAttrPath(driver string) string {
	return fmt.Sprintf(rebindAttrPathTmpl, driver)
}

// matchBusIDAttrPath returns the path of the driver's "match_busid" attribute.
func matchBusIDAttrPath(driver string) string {
	return fmt.Sprintf(matchBusIDAttrPathTmpl, driver)
}

// usbipStatusPath returns the path of the device's "usbip_status" attribute.
func usbipStatusPath(busID string) string {
	return fmt.Sprintf(usbipStatusPathTmpl, busID)
}

// usbipSockFdPath returns the path of the device's "usbip_sockfd" attribute.
func usbipSockFdPath(busID string) string {
	return fmt.Sprintf(usbipSockFdPathTmpl, busID)
}
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package usbip + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" + "github.com/deckhouse/virtualization-dra/pkg/usbip/protocol" +) + +const ( + defaultMaxTCPConnection = 100 + defaultGracefulShutdownTimeout = 30 * time.Second +) + +type Option func(usbipd *USBIPD) + +func WithGracefulShutdownTimeout(timeout time.Duration) Option { + return func(u *USBIPD) { + u.gracefulShutdownTimeout = timeout + } +} + +func WithMaxTCPConnection(maxTCPConnection int) Option { + return func(u *USBIPD) { + u.maxTCPConnection = maxTCPConnection + } +} + +func WithExport(enabled bool) Option { + return func(usbipd *USBIPD) { + usbipd.exportEnabled = enabled + } +} + +func NewUSBIPD(addr string, monitor libusb.Monitor, opts ...Option) *USBIPD { + usbipd := &USBIPD{ + addr: addr, + monitor: monitor, + gracefulShutdownTimeout: defaultGracefulShutdownTimeout, + maxTCPConnection: defaultMaxTCPConnection, + logger: slog.Default().With(slog.String("component", "usbipd")), + quit: make(chan struct{}), + } + + for _, opt := range opts { + opt(usbipd) + } + + if usbipd.exportEnabled { + usbipd.usbBinder = NewUSBBinder() + } + + return usbipd +} + +type USBIPD struct { + addr string + monitor libusb.Monitor + gracefulShutdownTimeout time.Duration + maxTCPConnection int + logger *slog.Logger + exportEnabled bool + usbBinder USBBinder + + listener net.Listener + connWg sync.WaitGroup + connCount atomic.Int64 + quit chan struct{} +} + +func (u 
*USBIPD) Start(ctx context.Context) error { + if err := u.setup(); err != nil { + return err + } + + go func() { + <-ctx.Done() + close(u.quit) + if u.listener != nil { + if err := u.listener.Close(); err != nil { + u.logger.Error("failed to close listener", slog.Any("error", err)) + } + } + }() + + u.connWg.Add(1) + go u.run(ctx) + + return nil +} + +func (u *USBIPD) Run(ctx context.Context) error { + if err := u.setup(); err != nil { + return err + } + + go func() { + <-ctx.Done() + close(u.quit) + if u.listener != nil { + if err := u.listener.Close(); err != nil { + u.logger.Error("failed to close listener", slog.Any("error", err)) + } + } + }() + + u.connWg.Add(1) + u.run(ctx) + + if waitWithTimeout(&u.connWg, u.gracefulShutdownTimeout) { + u.logger.Info("all connections closed") + } else { + u.logger.Warn("graceful shutdown timeout, some connections may be left open") + } + + return nil +} + +// waitWithTimeout waits for wg to complete; returns true if done before timeout. +func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +func (u *USBIPD) setup() (err error) { + u.listener, err = net.Listen("tcp", u.addr) + return err +} + +func (u *USBIPD) run(ctx context.Context) { + defer u.connWg.Done() + for { + conn, err := u.listener.Accept() + // Error occurred when + // 1. Connection error + // 2. The listener is closed (e.g. 
on context cancellation) + if err != nil { + select { + case <-u.quit: + return + default: + u.logger.Error("unable to accept request", slog.String("address", u.addr), slog.Any("err", err)) + } + continue + } + { + // Check if TCP connection reached the limit specified in given config + count := u.connCount.Load() + if count+1 > int64(u.maxTCPConnection) { + u.logger.Error("maximum TCP connection reached, drop the connection", slog.Int64("count", count)) + if err := conn.Close(); err != nil { + u.logger.Error("failed to close connection", slog.String("error", err.Error())) + } + continue + } + + // TCP connection handler + u.connWg.Add(1) + u.connCount.Add(1) + go func() { + defer u.connCount.Add(-1) + defer u.connWg.Done() + defer func() { + if err := conn.Close(); err != nil { + u.logger.Error("failed to close connection", slog.String("error", err.Error())) + } + }() + + u.logger.Info("new connection established", slog.String("addr", conn.RemoteAddr().String())) + keepConn, err := u.handleConnection(conn) + if err != nil { + if !errors.Is(err, io.EOF) { + u.logger.Error("failed to handle connection", slog.Any("err", err), slog.String("addr", conn.RemoteAddr().String())) + } else { + u.logger.Info("connection EOF", slog.String("addr", conn.RemoteAddr().String())) + } + } + if keepConn { + // don't handle and read from the socket. 
other work doing a kernel module + <-ctx.Done() + } + u.logger.Info("connection closed", slog.String("addr", conn.RemoteAddr().String())) + }() + } + } +} + +// https://docs.kernel.org/usb/usbip_protocol.html +// https://github.com/torvalds/linux/blob/9448598b22c50c8a5bb77a9103e2d49f134c9578/tools/usb/usbip/src/usbipd.c#L251 +func (u *USBIPD) handleConnection(conn net.Conn) (bool, error) { + opCommon := &protocol.OpCommon{} + if err := opCommon.Decode(conn); err != nil { + return false, fmt.Errorf("failed to decode OpCommon: %w", err) + } + + if opCommon.Version != protocol.Version { + return false, fmt.Errorf("unsupported USBIP version: %d", opCommon.Version) + } + + if opCommon.Status != protocol.OpStatusOk { + return false, fmt.Errorf("request failed: %s", opCommon.Status.String()) + } + + switch opCommon.Code { + case protocol.OpReqDevList: + if err := u.handleDeviceList(conn); err != nil { + return false, fmt.Errorf("failed to handle OpReqDevList: %w", err) + } + case protocol.OpReqImport: + if err := u.handleImportRequest(conn); err != nil { + return false, fmt.Errorf("failed to handle OpReqImport: %w", err) + } + return true, nil + case protocol.OpReqExport: + if err := u.handleExportRequest(conn); err != nil { + return false, fmt.Errorf("failed to handle OpRepExport: %w", err) + } + case protocol.OpReqUnexport: + if err := u.handleUnexportRequest(conn); err != nil { + return false, fmt.Errorf("failed to handle OpReqUnexport: %w", err) + } + case protocol.OpReqDevInfo, protocol.OpReqCrypkey: + // nothing to do + default: + return false, fmt.Errorf("unsupported OpCommon.Code: %d", opCommon.Code) + } + + return false, nil +} + +// https://github.com/torvalds/linux/blob/9448598b22c50c8a5bb77a9103e2d49f134c9578/tools/usb/usbip/src/usbipd.c#L229 +func (u *USBIPD) handleDeviceList(conn net.Conn) error { + info := u.getUSBDeviceInfo() + if len(info) == 0 { + slog.Info("no USB devices found") + } + devList := protocol.NewDeviceList(protocol.OpStatusOk, info) + 
return devList.Encode(conn) +} + +// https://github.com/torvalds/linux/blob/9448598b22c50c8a5bb77a9103e2d49f134c9578/tools/usb/usbip/src/usbipd.c#L91 +func (u *USBIPD) handleImportRequest(conn net.Conn) error { + importReq := &protocol.ImportRequest{} + if err := importReq.Decode(conn); err != nil { + return fmt.Errorf("failed to decode ImportRequest: %w", err) + } + + busID := importReq.BusID() + log := u.logger.With(slog.String("busID", busID)) + log.Info("import request") + + bindDevice, exists := u.monitor.GetDeviceByBusID(busID) + if !exists { + log.Info("USB device is not found") + return protocol.NewImportReply(protocol.OpStatusNoDev, protocol.USBDevice{}).Encode(conn) + } + + // should set TCP_NODELAY for usbip + u.setNoDelay(conn) + + status := u.exportDevice(conn, bindDevice) + if status != protocol.OpStatusOk { + log.Error("failed to export device", slog.String("status", status.String())) + } else { + u.logger.Info("device exported", slog.Any("device", bindDevice)) + } + + usbDevice := toUSBDeviceInfo(bindDevice).USBDevice + + return protocol.NewImportReply(status, usbDevice).Encode(conn) +} + +// https://github.com/torvalds/linux/blob/9448598b22c50c8a5bb77a9103e2d49f134c9578/tools/usb/usbip/libsrc/usbip_host_common.c#L212 +func (u *USBIPD) exportDevice(conn net.Conn, device *libusb.USBDevice) protocol.OpStatus { + log := u.logger.With(slog.String("busID", device.BusID)) + log.Info("export request") + + usbIpStatus, err := u.getUSBIPStatus(device) + if err != nil { + log.Error("failed to get USBIP status", slog.Any("error", err)) + return protocol.OpStatusError + } + + if usbIpStatus != protocol.DeviceStatusAvailable { + log.Info("USBIP status is not available") + switch usbIpStatus { + case protocol.DeviceStatusError: + log.Debug("USBIP status is error") + return protocol.OpStatusDevErr + case protocol.DeviceStatusUsed: + log.Debug("USBIP status is used") + return protocol.OpStatusDevBusy + default: + log.Debug("USBIP status unknown") + return 
protocol.OpStatusNA + } + } + + syscallConn, ok := conn.(syscall.Conn) + if !ok { + log.Error("conn does not implement syscall.Conn") + return protocol.OpStatusNA + } + + var sockFd int + rawConn, err := syscallConn.SyscallConn() + if err != nil { + log.Error("failed to get raw connection", slog.Any("error", err)) + return protocol.OpStatusNA + } + err = rawConn.Control(func(fd uintptr) { + sockFd = int(fd) + }) + if err != nil { + log.Error("failed to get socket fd", slog.Any("error", err)) + return protocol.OpStatusNA + } + + err = writeSysfsAttr(usbipSockFdPath(device.BusID), sockFdAttr{sockFd: sockFd}) + if err != nil { + log.Error("failed to write usbip_sockfd", slog.Any("error", err)) + return protocol.OpStatusNA + } + + log.Info("Connect") + + return protocol.OpStatusOk +} + +func (u *USBIPD) handleExportRequest(conn net.Conn) error { + if !u.exportEnabled { + u.logger.Info("USBIPD export is disabled, skip handle export request") + return nil + } + + exportRequest := &protocol.ExportRequest{} + if err := exportRequest.Decode(conn); err != nil { + return fmt.Errorf("failed to decode ExportRequest: %w", err) + } + + busID := exportRequest.BusID() + log := u.logger.With(slog.String("busID", busID)) + log.Info("export request") + + _, exists := u.monitor.GetDeviceByBusID(busID) + if !exists { + log.Info("USB device is not found") + return protocol.NewExportReply(protocol.OpStatusNoDev).Encode(conn) + } + + bound, err := u.usbBinder.IsBound(busID) + if err != nil { + log.Error("failed to check if USB device is bound", slog.Any("error", err)) + return protocol.NewExportReply(protocol.OpStatusError).Encode(conn) + } + + if bound { + log.Info("USB device is already bound") + return protocol.NewExportReply(protocol.OpStatusOk).Encode(conn) + } + + err = u.usbBinder.Bind(busID) + if err != nil { + log.Error("failed to bind USB device", slog.Any("error", err)) + return protocol.NewExportReply(protocol.OpStatusError).Encode(conn) + } + + log.Info("USB device bound") + 
return protocol.NewExportReply(protocol.OpStatusOk).Encode(conn) +} + +func (u *USBIPD) handleUnexportRequest(conn net.Conn) error { + if !u.exportEnabled { + u.logger.Info("USBIPD export is disabled, skip handle unexport request") + return nil + } + + unexportRequest := &protocol.UnExportRequest{} + if err := unexportRequest.Decode(conn); err != nil { + return fmt.Errorf("failed to decode UnExportRequest: %w", err) + } + + busID := unexportRequest.BusID() + log := u.logger.With(slog.String("busID", busID)) + log.Info("unexport request") + + _, exists := u.monitor.GetDeviceByBusID(busID) + if !exists { + log.Info("USB device is not found") + return protocol.NewUnExportReply(protocol.OpStatusNoDev).Encode(conn) + } + + bound, err := u.usbBinder.IsBound(busID) + if err != nil { + log.Error("failed to check if USB device is bound", slog.Any("error", err)) + return protocol.NewUnExportReply(protocol.OpStatusError).Encode(conn) + } + + if !bound { + log.Info("USB device already unbound") + return protocol.NewUnExportReply(protocol.OpStatusOk).Encode(conn) + } + + err = u.usbBinder.Unbind(busID) + if err != nil { + log.Error("failed to unbind USB device", slog.Any("error", err)) + return protocol.NewUnExportReply(protocol.OpStatusError).Encode(conn) + } + + log.Info("USB device unbound") + return protocol.NewUnExportReply(protocol.OpStatusOk).Encode(conn) +} + +type sockFdAttr struct { + sockFd int +} + +func (a sockFdAttr) Complete() string { + return fmt.Sprintf("%d\n", a.sockFd) +} + +func (u *USBIPD) getUSBIPStatus(device *libusb.USBDevice) (protocol.DeviceStatus, error) { + statusPath := usbipStatusPath(device.BusID) + + data, err := os.ReadFile(statusPath) + if err != nil { + return 0, fmt.Errorf("failed to read %s: %w", statusPath, err) + } + + statusStr := strings.TrimSpace(string(data)) + + value, err := strconv.ParseUint(statusStr, 10, 32) + if err != nil { + return 0, fmt.Errorf("invalid status value %q: %w", statusStr, err) + } + + status := 
protocol.DeviceStatus(value) + + return status, nil +} + +func (u *USBIPD) setNoDelay(conn net.Conn) { + tcpConn, ok := conn.(*net.TCPConn) + if ok { + err := tcpConn.SetNoDelay(true) + if err != nil { + u.logger.Error("failed to set TCP_NODELAY", slog.String("error", err.Error())) + } + return + } + u.logger.Error("failed to cast connection to TCPConn") +} + +// TODO: check already used devices +func (u *USBIPD) getUSBDeviceInfo() []protocol.USBDeviceInfo { + devices := u.monitor.GetDevices() + + var bindDevices []protocol.USBDeviceInfo + + for _, device := range devices { + if device.Driver == usbipHostDriverName { + bindDevice := toUSBDeviceInfo(&device) + bindDevices = append(bindDevices, bindDevice) + } + } + + return bindDevices +} + +func toUSBDeviceInfo(device *libusb.USBDevice) protocol.USBDeviceInfo { + if device == nil { + return protocol.USBDeviceInfo{} + } + return protocol.USBDeviceInfo{ + USBDevice: protocol.USBDevice{ + Path: protocol.ToDevicePath(device.DevicePath), + BusID: protocol.ToBusID(device.BusID), + Busnum: device.Bus, + Devnum: device.DeviceNumber, + Speed: toSpeed(device.Speed), + IDVendor: device.VendorID, + IDProduct: device.ProductID, + BcdDevice: device.BCD, + BDeviceClass: device.BDeviceClass, + BDeviceSubClass: device.BDeviceSubClass, + BDeviceProtocol: device.BDeviceProtocol, + BConfigurationValue: device.BConfigurationValue, + BNumConfigurations: device.BNumConfigurations, + BNumInterfaces: device.BNumInterfaces, + }, + Interfaces: toInterfaces(device.Interfaces), + } +} + +func toInterfaces(interfaces []libusb.USBDeviceInterface) []protocol.USBDeviceInterface { + result := make([]protocol.USBDeviceInterface, len(interfaces)) + for i, iface := range interfaces { + result[i] = protocol.USBDeviceInterface{ + BInterfaceClass: iface.BInterfaceClass, + BInterfaceSubClass: iface.BInterfaceSubClass, + BInterfaceProtocol: iface.BInterfaceProtocol, + } + } + return result +} + +func toSpeed(speed uint32) uint32 { + return 
uint32(libusb.ResolveDeviceSpeed(speed)) +} diff --git a/images/virtualization-dra/pkg/usbip/usbipd_config.go b/images/virtualization-dra/pkg/usbip/usbipd_config.go new file mode 100644 index 0000000000..1304b9bcb4 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/usbipd_config.go @@ -0,0 +1,61 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package usbip + +import ( + "net" + "os" + "strconv" + "time" + + "github.com/spf13/pflag" + + "github.com/deckhouse/virtualization-dra/pkg/libusb" +) + +type USBIPDConfig struct { + Address string + Port int + GracefulShutdownTimeout time.Duration + MaxTcpConnections int + ExportEnabled bool +} + +func (c *USBIPDConfig) AddFlags(fs *pflag.FlagSet) { + fs.StringVar(&c.Address, "usbipd-address", os.Getenv("USBIPD_ADDRESS"), "USBIPD address") + fs.IntVar(&c.Port, "usbipd-port", 3240, "USBIPD port") + fs.DurationVar(&c.GracefulShutdownTimeout, "usbipd-graceful-shutdown-timeout", 0, "USBIPD graceful shutdown timeout") + fs.IntVar(&c.MaxTcpConnections, "usbipd-max-tcp-connections", 0, "USBIPD max TCP connections") + fs.BoolVar(&c.ExportEnabled, "usbipd-export-enabled", false, "USBIPD export enabled") +} + +func (c *USBIPDConfig) Complete(monitor libusb.Monitor) (*USBIPD, error) { + var opts []Option + if c.GracefulShutdownTimeout > 0 { + opts = append(opts, WithGracefulShutdownTimeout(c.GracefulShutdownTimeout)) + } + if c.MaxTcpConnections > 0 { + opts = append(opts, WithMaxTCPConnection(c.MaxTcpConnections)) 
+ } + if c.ExportEnabled { + opts = append(opts, WithExport(true)) + } + + address := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + + return NewUSBIPD(address, monitor, opts...), nil +} diff --git a/images/virtualization-dra/pkg/usbip/vhci.go b/images/virtualization-dra/pkg/usbip/vhci.go new file mode 100644 index 0000000000..7fbb048867 --- /dev/null +++ b/images/virtualization-dra/pkg/usbip/vhci.go @@ -0,0 +1,213 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package usbip + +import ( + "bytes" + "fmt" + "os" + "strconv" + "strings" +) + +const ( + vhciStatePath = "/var/run/vhci_hcd" + platformPath = "/sys/devices/platform" + usbipVhciHcdNPortsPath = "/sys/devices/platform/vhci_hcd.0/nports" + + vhciHcdAttach = "/sys/devices/platform/vhci_hcd.0/attach" + vhciHcdDetach = "/sys/devices/platform/vhci_hcd.0/detach" + vhciHcdStatus = "/sys/devices/platform/vhci_hcd.0/status" + secondaryVhciHcdStatusTmpl = "/sys/devices/platform/vhci_hcd.%d/status.%d" + + vhciStatePortTmpl = "/var/run/vhci_hcd/port%d" +) + +func vhciStatePortPath(port int) string { + return fmt.Sprintf(vhciStatePortTmpl, port) +} + +func secondaryVhciHcdStatusPath(count int) string { + return fmt.Sprintf(secondaryVhciHcdStatusTmpl, count, count) +} + +type vhciDriver struct { + nports int + ncontrollers int + idevs []importDevice +} + +type importDevice struct { + hub hubSpeed + port, status, devID, busnum, devnum int + localBusID string +} + +type hubSpeed int + +const ( + 
hubSpeedHigh hubSpeed = iota + hubSpeedSuper +) + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/libsrc/vhci_driver.c#L243 +func newVhciDriver() (*vhciDriver, error) { + nports, err := getNPorts() + if err != nil { + return nil, err + } + ncontrollers, err := getNControllers() + if err != nil { + return nil, err + } + + driver := &vhciDriver{ + nports: nports, + ncontrollers: ncontrollers, + } + + err = driver.refreshImportDeviceList() + if err != nil { + return nil, fmt.Errorf("failed to refresh import device list: %w", err) + } + + return driver, nil +} + +func getNPorts() (int, error) { + data, err := os.ReadFile(usbipVhciHcdNPortsPath) + if err != nil { + return -1, err + } + + nports, err := strconv.Atoi(strings.TrimSpace(string(data))) + if err != nil { + return -1, err + } + + return nports, nil +} + +func getNControllers() (int, error) { + entries, err := os.ReadDir(platformPath) + if err != nil { + return -1, err + } + count := 0 + for _, entry := range entries { + if entry.IsDir() && strings.HasPrefix(entry.Name(), "vhci_hcd") { + count++ + } + } + + return count, nil +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/libsrc/vhci_driver.c#L111 +func (d *vhciDriver) refreshImportDeviceList() error { + status := vhciHcdStatus + + for i := 0; i < d.ncontrollers; i++ { + if i > 0 { + status = secondaryVhciHcdStatusPath(i) + } + + attrStatus, err := os.ReadFile(status) + if err != nil { + return fmt.Errorf("failed to read %s: %w", status, err) + } + + err = d.parseStatus(attrStatus) + if err != nil { + return fmt.Errorf("failed to parse attr status %s: %w", status, err) + } + } + + return nil +} + +// https://github.com/torvalds/linux/blob/b927546677c876e26eba308550207c2ddf812a43/tools/usb/usbip/libsrc/vhci_driver.c#L40 +func (d *vhciDriver) parseStatus(statusBytes []byte) error { + lines := strings.Split(string(statusBytes), "\n") + + // hub port sta spd 
dev sockfd local_busid + // hs 0000 004 000 00000000 000000 0-0 + // hs 0001 004 000 00000000 000000 0-0 + // hs 0002 004 000 00000000 000000 0-0 + + head := true + for _, line := range lines { + if head { + // skip header + head = false + continue + } + + if strings.TrimSpace(line) == "" { + continue + } + + var ( + hub string + port, status, speed, devID, sockFd int + localBusID string + ) + + buf := bytes.NewBufferString(line) + _, err := fmt.Fscanf(buf, "%2s %d %d %d %x %d %31s", &hub, &port, &status, &speed, &devID, &sockFd, &localBusID) + if err != nil { + return fmt.Errorf("failed to parse status: %w", err) + } + + if len(d.idevs) <= port { + idevs := make([]importDevice, port+1) + copy(idevs, d.idevs) + d.idevs = idevs + } + + busnum, devnum := getBusNumDevNum(devID) + + idev := &d.idevs[port] + + idev.port = port + idev.status = status + idev.devID = devID + idev.busnum = busnum + idev.devnum = devnum + idev.localBusID = localBusID + + switch hub { + case "hs": + idev.hub = hubSpeedHigh + case "ss": + idev.hub = hubSpeedSuper + } + } + + return nil +} + +func getDevId(busnum, devnum uint32) int { + return int((busnum << 16) | devnum) +} + +func getBusNumDevNum(devID int) (int, int) { + busnum := devID >> 16 + devnum := devID & 0x0000ffff + + return busnum, devnum +} diff --git a/images/virtualization-dra/test/pod-with-template-4.yaml b/images/virtualization-dra/test/pod-with-template-4.yaml new file mode 100644 index 0000000000..400412ef7a --- /dev/null +++ b/images/virtualization-dra/test/pod-with-template-4.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: Pod +metadata: + name: test-pod-with-usb-template-4 + namespace: usb +spec: + containers: + - name: test-container + image: nicolaka/netshoot:latest + command: ["sleep", "3600"] + resources: + claims: + - name: usb-device + resourceClaims: + - name: usb-device + resourceClaimTemplateName: usb-product-0951-vendor-0104-template diff --git a/images/virtualization-dra/test/resourceclaim-template-2.yaml 
b/images/virtualization-dra/test/resourceclaim-template-2.yaml index ee66b7270c..7f4a85b237 100644 --- a/images/virtualization-dra/test/resourceclaim-template-2.yaml +++ b/images/virtualization-dra/test/resourceclaim-template-2.yaml @@ -1,17 +1,22 @@ -apiVersion: resource.k8s.io/v1beta1 +apiVersion: resource.k8s.io/v1 kind: ResourceClaimTemplate metadata: name: usb-product-0951-vendor-0104-template spec: + metadata: + annotations: + "usb.virtualization.deckhouse.io/device-user": "107" + "usb.virtualization.deckhouse.io/device-group": "107" spec: devices: requests: - name: req-0 - allocationMode: "ExactCount" - count: 1 - deviceClassName: usb-devices.virtualization.deckhouse.io - selectors: - - cel: - expression: |- - device.attributes["virtualization-usb"].productID == "0104" && - device.attributes["virtualization-usb"].vendorID == "0951" + exactly: + allocationMode: "ExactCount" + count: 1 + deviceClassName: usb-devices.virtualization.deckhouse.io + selectors: + - cel: + expression: |- + device.attributes["virtualization-usb"].productID == "0104" && + device.attributes["virtualization-usb"].vendorID == "0951" diff --git a/images/virtualization-dra/werf.inc.yaml b/images/virtualization-dra/werf.inc.yaml index c8b48367e7..1195ee3573 100644 --- a/images/virtualization-dra/werf.inc.yaml +++ b/images/virtualization-dra/werf.inc.yaml @@ -38,3 +38,8 @@ shell: {{- else }} {{- include "image-build.build" (set $ "BuildCommand" `go build -ldflags="-s -w" -v -o /out/virtualization-dra-usb ./cmd/usb/dra`) | nindent 6 }} {{- end }} + + - | + echo "Build go-usbip binary" + {{- $_ := set $ "ProjectName" (list $.ImageName "go-usbip" | join "/") }} + {{- include "image-build.build" (set $ "BuildCommand" `go build -ldflags="-s -w" -v -o /out/go-usbip ./cmd/usb/go-usbip`) | nindent 6 }} diff --git a/module.yaml b/module.yaml index 218f653b9a..ceb45a8d50 100644 --- a/module.yaml +++ b/module.yaml @@ -8,7 +8,7 @@ descriptions: ru: Запускает и управляет виртуальными машинами в платформе 
Deckhouse. tags: ["virtualization"] requirements: - deckhouse: ">= 1.74.2" + deckhouse: ">= 1.74" # TODO: FOR TEST. REVERT ME ">= 1.74.2" modules: cni-cilium: ">= 0.0.0" disable: diff --git a/templates/virtualization-dra/_helper.tpl b/templates/virtualization-dra/_helper.tpl index 711359d8bb..19ab94ec10 100644 --- a/templates/virtualization-dra/_helper.tpl +++ b/templates/virtualization-dra/_helper.tpl @@ -1,5 +1,7 @@ {{- define "virtualization-dra.isEnabled" -}} {{- if eq (include "hasValidModuleConfig" .) "true" -}} -false +{{- if semverCompare ">=1.34" .Values.global.discovery.kubernetesVersion -}} +true +{{- end -}} {{- end -}} {{- end -}} diff --git a/templates/virtualization-dra/daemonset.yaml b/templates/virtualization-dra/daemonset.yaml index 71526466b7..acbc0b4f1b 100644 --- a/templates/virtualization-dra/daemonset.yaml +++ b/templates/virtualization-dra/daemonset.yaml @@ -1,5 +1,10 @@ {{- $priorityClassName := include "priorityClassName" . }} {{- $delve := (include "delve" . | fromYaml) -}} +{{- define "virtualization-dra_init_resources" }} +cpu: 10m +memory: 25Mi +{{- end }} + {{- define "virtualization-dra_resources" }} cpu: 10m memory: 25Mi @@ -22,7 +27,7 @@ spec: kind: DaemonSet name: virtualization-dra updatePolicy: - updateMode: "Auto" + updateMode: {{ include "vpa.policyUpdateMode" . }} resourcePolicy: containerPolicies: - containerName: virtualization-dra @@ -56,14 +61,31 @@ spec: - name: virtualization-module-registry serviceAccountName: virtualization-dra dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true nodeSelector: kubernetes.io/os: linux - hostNetwork: true + initContainers: + - name: virtualization-dra-init + image: {{ include "helm_lib_module_image" (list . "virtualizationDraUsb") }} + args: + - init + {{- include "helm_lib_module_container_security_context_read_only_root_filesystem_capabilities_drop_all_and_add" (list . 
{{- if eq (include "virtualization-dra.isEnabled" .) "true"}}
# Node-level step that installs the distribution package shipping the usbip
# kernel modules (usbip_core, usbip_host, vhci_hcd). Applied to every bundle
# and every node group.
apiVersion: deckhouse.io/v1alpha1
kind: NodeGroupConfiguration
metadata:
  name: virtualization-install-usbip-modules
  {{- include "helm_lib_module_labels" (list .) | nindent 2 }}
spec:
  bundles:
    - '*'
  nodeGroups:
    - '*'
  weight: 30
  content: |
    # Copyright 2026 Flant JSC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    pkg=""
    bundle="$(bb-is-bundle)"

    if bb-is-distro-like? "debian"; then
      # Debian, Ubuntu, Astra (Debian-based): modules in linux-modules-extra-$(uname -r)
      pkg="linux-modules-extra-$(uname -r)"
    elif bb-is-distro-like? "rhel" || bb-is-distro-like? "centos" || bb-is-distro-like? "fedora"; then
      # RHEL, CentOS, Fedora, ROSA, REDOS: kernel-modules-extra (match by ID or ID_LIKE)
      pkg="kernel-modules-extra"
    elif [[ "$bundle" == "altlinux" ]]; then
      # Alt Linux
      pkg="kernel-modules-extra"
    else
      # The message announces a skip, so finish the step successfully instead
      # of reporting an error for an OS this step does not know about.
      # NOTE(review): assumes a non-zero return marks the step as failed and
      # blocks the node's configuration run - confirm against bashible.
      bb-log-warn "Unsupported OS for usbip kernel modules: bundle=${bundle}. Skipping."
      return 0
    fi

    # Defensive check: every supported branch above sets pkg.
    if [[ -z "$pkg" ]]; then
      bb-log-warn "Could not determine package for usbip kernel modules."
      return 1
    fi

    bb-log-info "Installing package for usbip kernel modules (usbip_core, usbip_host, vhci_hcd): ${pkg}"
    bb-pkg install "$pkg"
{{- end }}