diff --git a/deploy/rbac-csi-azurefile-node.yaml b/deploy/rbac-csi-azurefile-node.yaml
index 61f0f0c8a4..72c87e5470 100644
--- a/deploy/rbac-csi-azurefile-node.yaml
+++ b/deploy/rbac-csi-azurefile-node.yaml
@@ -54,3 +54,24 @@ roleRef:
   name: csi-azurefile-node-katacc-role
   apiGroup: rbac.authorization.k8s.io
 ---
+kind: ClusterRole
+apiVersion: rbac.authorization.k8s.io/v1
+metadata:
+  name: csi-azurefile-node-role
+rules:
+  - apiGroups: [""]
+    resources: ["nodes"]
+    verbs: ["get"]
+---
+kind: ClusterRoleBinding
+apiVersion: rbac.authorization.k8s.io/v1
+metadata:
+  name: csi-azurefile-node-binding
+subjects:
+  - kind: ServiceAccount
+    name: csi-azurefile-node-sa
+    namespace: kube-system
+roleRef:
+  kind: ClusterRole
+  name: csi-azurefile-node-role
+  apiGroup: rbac.authorization.k8s.io
diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go
index e79f617ad4..99608772a7 100644
--- a/pkg/azurefile/azurefile.go
+++ b/pkg/azurefile/azurefile.go
@@ -453,7 +453,15 @@ func (d *Driver) Run(ctx context.Context) error {
 	csi.RegisterControllerServer(server, d)
 	csi.RegisterNodeServer(server, d)
 	d.server = server
+	// Log the kata isolation labels once at startup for debuggability.
+	// Skip on controller deployments where NodeID is empty to avoid a
+	// guaranteed-failing node lookup and spurious warnings.
+	if d.NodeID != "" {
+		if val, val2, err := getNodeInfoFromLabels(ctx, d.NodeID, d.kubeClient); err != nil {
+			klog.Warningf("failed to get node info from labels: %v", err)
+		} else {
+			klog.V(2).Infof("node(%s) kata isolation labels: %q, %q", d.NodeID, val, val2)
+		}
+	}
 	listener, err := csicommon.ListenEndpoint(d.endpoint)
 	if err != nil {
 		klog.Fatalf("failed to listen endpoint: %v", err)
@@ -1274,3 +1279,33 @@ func (d *Driver) getStorageEndPointSuffix() string {
 	}
 	return d.cloud.Environment.StorageEndpointSuffix
 }
+
+// getNodeInfoFromLabels returns the values of the kata isolation labels
+// ("kubernetes.azure.com/kata-mshv-vm-isolation" and
+// "katacontainers.io/kata-runtime") on the given node. Missing labels
+// yield empty strings; a node with no labels at all is an error.
+func getNodeInfoFromLabels(ctx context.Context, nodeID string, kubeClient kubernetes.Interface) (string, string, error) {
+	if kubeClient == nil || kubeClient.CoreV1() == nil {
+		return "", "", fmt.Errorf("kubeClient is nil")
+	}
+
+	node, err := kubeClient.CoreV1().Nodes().Get(ctx, nodeID, metav1.GetOptions{})
+	if err != nil {
+		return "", "", fmt.Errorf("get node(%s) failed with %v", nodeID, err)
+	}
+
+	if len(node.Labels) == 0 {
+		return "", "", fmt.Errorf("node(%s) label is empty", nodeID)
+	}
+	return node.Labels["kubernetes.azure.com/kata-mshv-vm-isolation"], node.Labels["katacontainers.io/kata-runtime"], nil
+}
+
+// isNodeConfidential reports whether the node carries a kata/confidential
+// isolation label set to "true". On any lookup error it conservatively
+// returns false (treat the node as non-confidential).
+func isNodeConfidential(ctx context.Context, nodeID string, kubeClient kubernetes.Interface) bool {
+	val, val2, err := getNodeInfoFromLabels(ctx, nodeID, kubeClient)
+	if err != nil {
+		klog.Warningf("get node(%s) confidential label failed with %v", nodeID, err)
+		return false
+	}
+	return val == "true" || val2 == "true"
+}
diff --git a/pkg/azurefile/nodeserver.go b/pkg/azurefile/nodeserver.go
index c961da8188..10985d1d68 100644
--- a/pkg/azurefile/nodeserver.go
+++ b/pkg/azurefile/nodeserver.go
@@ -101,8 +101,8 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
 	}
 
 	if d.enableKataCCMount {
-		enableKataCCMount := getValueInMap(context, enableKataCCMountField)
-		if strings.EqualFold(enableKataCCMount, trueValue) && context[podNameField] != "" && context[podNamespaceField] != "" {
+		enableKataCCMount := isNodeConfidential(ctx, d.NodeID, d.kubeClient)
+		if enableKataCCMount && context[podNameField] != "" && context[podNamespaceField] != "" {
 			runtimeClass, err := getRuntimeClassForPodFunc(ctx, d.kubeClient, context[podNameField], context[podNamespaceField])
 			if err != nil {
 				return nil, status.Errorf(codes.Internal, "failed to get runtime class for pod %s/%s: %v", context[podNamespaceField], context[podNameField], err)
@@ -252,7 +252,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
 	// don't respect fsType from req.GetVolumeCapability().GetMount().GetFsType()
 	// since it's ext4 by default on Linux
 	var fsType, server, protocol, ephemeralVolMountOptions, storageEndpointSuffix, folderName string
-	var ephemeralVol, enableKataCCMount bool
+	var ephemeralVol bool
 	fileShareNameReplaceMap := map[string]string{}
 
 	mountPermissions := d.mountPermissions
@@ -284,11 +284,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
 			fileShareNameReplaceMap[pvcNameMetadata] = v
 		case pvNameKey:
 			fileShareNameReplaceMap[pvNameMetadata] = v
-		case enableKataCCMountField:
-			enableKataCCMount, err = strconv.ParseBool(v)
-			if err != nil {
-				return nil, status.Errorf(codes.InvalidArgument, "invalid %s: %s in storage class", enableKataCCMountField, v)
-			}
 		case mountPermissionsField:
 			if v != "" {
 				var err error
@@ -423,7 +418,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
 		}
 		klog.V(2).Infof("volume(%s) mount %s on %s succeeded", volumeID, source, cifsMountPath)
 	}
-
+	// short-circuit: only hit the API server for node labels when the feature gate is on
+	enableKataCCMount := d.enableKataCCMount && isNodeConfidential(ctx, d.NodeID, d.kubeClient)
 	// If runtime OS is not windows and protocol is not nfs, save mountInfo.json
 	if d.enableKataCCMount && enableKataCCMount {
 		if runtime.GOOS != "windows" && protocol != nfs {