diff --git a/pkg/extensionerrors/extensionerrors.go b/pkg/extensionerrors/extensionerrors.go index c2ca68f..7465096 100644 --- a/pkg/extensionerrors/extensionerrors.go +++ b/pkg/extensionerrors/extensionerrors.go @@ -44,4 +44,26 @@ var ( ErrNotFound = errors.New("NotFound") ErrInvalidOperationName = errors.New("operation name is invalid") + + ErrMissingPolicyFile = errors.New("policy file is missing") + + ErrInvalidPolicyFile = errors.New("policy file is invalid") + + ErrEmptyPolicyFile = errors.New("policy file is empty") + + ErrEmptyPolicyFilePath = errors.New("the path to the policy file cannot be empty") + + ErrFailedToUnmarshalPolicyFile = errors.New("failed to unmarshal policy file") + + ErrPolicyNotYetLoaded = errors.New("policy settings have not yet been loaded") + + ErrPolicyValidationFailed = errors.New("policy validation failed") + + ErrPolicyAllowlistEmpty = errors.New("file is not in allowlist because the allowlist is empty") + + ErrItemNotInAllowlist = errors.New("item is not in the allowlist") + + ErrEmptyFilepathToValidate = errors.New("filepath of the file to validate cannot be empty") + + ErrFailedToReadFileToValidate = errors.New("failed to read file to validate") ) diff --git a/pkg/extensionpolicysettings/extensionpolicysettings.go b/pkg/extensionpolicysettings/extensionpolicysettings.go new file mode 100644 index 0000000..4e7eeea --- /dev/null +++ b/pkg/extensionpolicysettings/extensionpolicysettings.go @@ -0,0 +1,147 @@ +package extensionpolicysettings + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "os" + + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" +) + +type ExtensionPolicySettings interface { + ValidateFormat() error +} + +type ExtensionPolicySettingsManager[T ExtensionPolicySettings] struct { + settingsFilePath string + settings *T +} + +func NewExtensionPolicySettingsManager[T ExtensionPolicySettings](policyFilePath string) (*ExtensionPolicySettingsManager[T], error) { + if policyFilePath == "" { + return nil, extensionerrors.ErrEmptyPolicyFilePath + } + return &ExtensionPolicySettingsManager[T]{ + settingsFilePath: policyFilePath, + }, nil +} + +func (epsm *ExtensionPolicySettingsManager[T]) LoadExtensionPolicySettings() error { + if epsm == nil { + return fmt.Errorf("invalid ExtensionPolicySettingsManager: manager is nil") + } + if epsm.settingsFilePath == "" { + return extensionerrors.ErrEmptyPolicyFilePath + } + + // If an extension has a default policy configuration in case the file does not exist, they should handle that logic before calling this function. + if _, err := os.Stat(epsm.settingsFilePath); os.IsNotExist(err) { + return extensionerrors.ErrMissingPolicyFile + } else if err != nil { + return fmt.Errorf("error checking extension policy settings file: %w", err) + } + + fileContent, err := os.ReadFile(epsm.settingsFilePath) + if err != nil { + return fmt.Errorf("failed to read extension policy settings file: %w", err) // TODO: Add retry logic if appropriate. + } + + if len(fileContent) == 0 { + return extensionerrors.ErrEmptyPolicyFile + } + + var settings *T = new(T) + if err := json.Unmarshal(fileContent, settings); err != nil { + return fmt.Errorf("failed to unmarshal extension policy settings: %w", err) + } + + // Extensions themselves must decide the criteria for valid policy settings (i.e., if they can be null etc.). + if err := (*settings).ValidateFormat(); err != nil { + return fmt.Errorf("extension policy loaded, but invalid format: %w", err) + } + + epsm.settings = settings + return nil +} + +func (epsm *ExtensionPolicySettingsManager[T]) GetSettings() (*T, error) { + if epsm.settings == nil { + return nil, extensionerrors.ErrPolicyNotYetLoaded + } + return epsm.settings, nil +} + +// Validation Helper Functions +type HashType int + +const ( + HashTypeNone HashType = iota + HashTypeSHA1 + HashTypeSHA256 +) + +func ValidateValueInAllowlist(value string, allowlist []string) error { + if len(allowlist) == 0 { + return extensionerrors.ErrPolicyAllowlistEmpty + } + + for _, allowlistValue := range allowlist { + if value == allowlistValue { + return nil + } + } + return extensionerrors.ErrItemNotInAllowlist +} + +// This function is the entry point for most use cases: it takes in the filepath, reads the content, and +// determines if the content is allowlisted. If hashOpt is not HashTypeNone, it will compute the hash of the file content. +// If extensions don't want to validate a filepath but a value directly, they can call ValidateValueInAllowlist, +// which this function calls. +func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt HashType) error { + if len(allowlist) == 0 { + return extensionerrors.ErrPolicyAllowlistEmpty + } + + if filePath == "" { + return extensionerrors.ErrEmptyFilepathToValidate + } + + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return fmt.Errorf("file to validate does not exist: %w", err) + } + + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read file %s for validation: %w", filePath, err) + } + + value := string(content) + + if hashOpt != HashTypeNone { + value, err := ComputeFileHash(value, hashOpt) + if err != nil { + return fmt.Errorf("error occured when hashing contents of file %s for validation: %w", filePath, err) + } + return ValidateValueInAllowlist(value, allowlist) + } + + return ValidateValueInAllowlist(value, allowlist) +} + +// ComputeFileHash computes the hash of a file or leaves string as is. +func ComputeFileHash(contents string, hashOpt HashType) (string, error) { + var hashStr string + switch hashOpt { + case HashTypeSHA1: + hash := sha1.Sum([]byte(contents)) + hashStr = hex.EncodeToString(hash[:]) + default: + hash := sha256.Sum256([]byte(contents)) + hashStr = hex.EncodeToString(hash[:]) + } + + return hashStr, nil +} diff --git a/pkg/extensionpolicysettings/extensionpolicysettings_test.go b/pkg/extensionpolicysettings/extensionpolicysettings_test.go new file mode 100644 index 0000000..836a6b3 --- /dev/null +++ b/pkg/extensionpolicysettings/extensionpolicysettings_test.go @@ -0,0 +1,196 @@ +// filepath: /home/anasanc/repos/azure-extension-platform/pkg/extensionpolicysettings/extensionpolicysettings_test.go +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package extensionpolicysettings + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "testing" + + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" + "github.com/stretchr/testify/require" +) + +const extensionRuntimePolicySettingsFilePath = "./testutils/runtime_policy.json" + +// This is a sample struct for an example extension's policy settings. +// Each extension will define their own struct that implements the ExtensionPolicySettings interface according to their needs. +type TestPolicy struct { + RequiresSigning string `json:"requireSigning"` + AllowedScripts []string `json:"allowedScripts"` +} + +func (tp TestPolicy) ValidateFormat() error { + // In a real extension, you would implement logic to validate the policy was correctly loaded. + return nil +} + +func TestNewExtensionPolicySettingsManager(t *testing.T) { + // Create a new ExtensionPolicySettingsManager + manager, err := NewExtensionPolicySettingsManager[TestPolicy](extensionRuntimePolicySettingsFilePath) + require.NoError(t, err) + require.NotNil(t, manager) + require.Equal(t, extensionRuntimePolicySettingsFilePath, manager.settingsFilePath) + require.Nil(t, manager.settings) // settings should not be loaded until LoadExtensionPolicySettings is called +} + +func TestLoadExtensionPolicySettings(t *testing.T) { + // Setup test parameters + manager, err := NewExtensionPolicySettingsManager[TestPolicy](extensionRuntimePolicySettingsFilePath) + require.NoError(t, err) + + // Test cases: + // 1. Valid policy file: we should be able to load the settings without error + validPolicyContent := `{ + "requireSigning": "true", + "allowedScripts": [] + }` + writeToFile(extensionRuntimePolicySettingsFilePath, validPolicyContent) + defer cleanupFile(extensionRuntimePolicySettingsFilePath) + + // Call LoadExtensionPolicySettings and check for errors + err = manager.LoadExtensionPolicySettings() + require.NoError(t, err) + require.NotNil(t, manager.settings) + require.Equal(t, "true", manager.settings.RequiresSigning) + require.Empty(t, manager.settings.AllowedScripts) + + // 2. Invalid policy file (e.g. not valid json): we should get an error when trying to load the settings + invalidPolicyContent := `{` + writeToFile(extensionRuntimePolicySettingsFilePath, invalidPolicyContent) + err = manager.LoadExtensionPolicySettings() + require.Error(t, err) + + // 3. Empty policy file: we should get an error indicating the policy file is empty + writeToFile(extensionRuntimePolicySettingsFilePath, "") + err = manager.LoadExtensionPolicySettings() + require.ErrorIs(t, err, extensionerrors.ErrEmptyPolicyFile) + + // 5. Locked policy file: we should get an error indicating the file cannot be accessed. + // modify the file permissions to simulate a locked file (read-only file) + os.Chmod(extensionRuntimePolicySettingsFilePath, 0200) // write-only permissions + err = manager.LoadExtensionPolicySettings() + require.Error(t, err) + + // 5. Missing policy file: we should get an error indicating the policy file is missing + cleanupFile(extensionRuntimePolicySettingsFilePath) + err = manager.LoadExtensionPolicySettings() + require.ErrorIs(t, err, extensionerrors.ErrMissingPolicyFile) +} + +func TestGetSettings(t *testing.T) { + // Setup test parameters + manager, err := NewExtensionPolicySettingsManager[TestPolicy](extensionRuntimePolicySettingsFilePath) + require.NoError(t, err) + validPolicyContent := `{ + "requireSigning": "true", + "allowedScripts": [] + }` + require.NoError(t, writeToFile(extensionRuntimePolicySettingsFilePath, validPolicyContent)) + defer cleanupFile(extensionRuntimePolicySettingsFilePath) + + // Call LoadExtensionPolicySettings and check for errors + _, err = manager.GetSettings() + require.ErrorIs(t, err, extensionerrors.ErrPolicyNotYetLoaded) // should return an error because settings have not been loaded yet + err = manager.LoadExtensionPolicySettings() + require.NoError(t, err) + require.NotNil(t, manager.settings) + require.Equal(t, "true", manager.settings.RequiresSigning) + + // Call GetSettings and check for errors + settings, err := manager.GetSettings() + require.NoError(t, err) + require.NotNil(t, settings) + require.Equal(t, "true", settings.RequiresSigning) + require.Empty(t, settings.AllowedScripts) +} + +func TestValidateAgainstAllowlist(t *testing.T) { + // Setup test parameters + manager, err := NewExtensionPolicySettingsManager[TestPolicy](extensionRuntimePolicySettingsFilePath) + require.NoError(t, err) + defer cleanupFile(extensionRuntimePolicySettingsFilePath) // Clean up after test + + script1Hash, err := hashHelper("./testutils/testscripts/script1.sh", TestHashTypeSha256) + require.NoError(t, err) + script2Hash, err := hashHelper("./testutils/testscripts/script2.sh", TestHashTypeSha256) + require.NoError(t, err) + // Skip computing script3 hash because it will not be allowed.. + script4Hash := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // pre-computed hash of the empty string + script5Hash, err := hashHelper("./testutils/testscripts/script5.sh", TestHashTypeSha1) + require.NoError(t, err) + + // Some scripts are allowed + validPolicyContent := fmt.Sprintf(`{ + "requireSigning": "true", + "allowedScripts": ["%s", "%s", "%s", "%s"] + }`, script1Hash, script2Hash, script4Hash, script5Hash) + require.NoError(t, writeToFile(extensionRuntimePolicySettingsFilePath, validPolicyContent)) + + // Call LoadExtensionPolicySettings and check for errors + err = manager.LoadExtensionPolicySettings() + require.NoError(t, err) + require.NotNil(t, manager.settings) + require.Equal(t, "true", manager.settings.RequiresSigning) + require.NotEmpty(t, manager.settings.AllowedScripts) + + require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", manager.settings.AllowedScripts, HashTypeSHA256)) + require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script2.sh", manager.settings.AllowedScripts, HashTypeSHA256)) + require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script3.sh", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrItemNotInAllowlist) + require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script5.sh", manager.settings.AllowedScripts, HashTypeSHA1)) + + // Empty filepath + require.ErrorIs(t, ValidateFileHashInAllowlist("", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrEmptyFilepathToValidate) + // Missing file + require.Error(t, ValidateFileHashInAllowlist("./testutils/testscripts/missing.sh", manager.settings.AllowedScripts, HashTypeSHA256)) + // Now, empty list. + require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", []string{}, HashTypeSHA256), extensionerrors.ErrPolicyAllowlistEmpty) + // Empty file + require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script4.sh", manager.settings.AllowedScripts, HashTypeSHA256)) + +} + +// Helper functions for tests + +func writeToFile(filePath, content string) error { + err := os.WriteFile(filePath, []byte(content), 0644) + return err +} + +func cleanupFile(path string) { + if _, err := os.Stat(path); err == nil { + os.Remove(path) + } +} + +type TestHashType int + +const ( + TestHashTypeSha1 TestHashType = iota + TestHashTypeSha256 +) + +func hashHelper(filePath string, hashOpt TestHashType) (string, error) { + contents, err := os.ReadFile(filePath) + + if err != nil { + return "", err + } + + var hashStr string + switch hashOpt { + case TestHashTypeSha1: + hash := sha1.New() + hash.Write(contents) + hashStr = hex.EncodeToString(hash.Sum(nil)) + case TestHashTypeSha256: + hash := sha256.New() + hash.Write(contents) + hashStr = hex.EncodeToString(hash.Sum(nil)) + } + return hashStr, nil +} diff --git a/pkg/extensionpolicysettings/testutils/testscripts/script1.sh b/pkg/extensionpolicysettings/testutils/testscripts/script1.sh new file mode 100644 index 0000000..11a48d9 --- /dev/null +++ b/pkg/extensionpolicysettings/testutils/testscripts/script1.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# This is a simple shell script +echo "Hello, World! I am script1.sh" \ No newline at end of file diff --git a/pkg/extensionpolicysettings/testutils/testscripts/script2.sh b/pkg/extensionpolicysettings/testutils/testscripts/script2.sh new file mode 100644 index 0000000..937a424 --- /dev/null +++ b/pkg/extensionpolicysettings/testutils/testscripts/script2.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# This is a simple shell script +echo "Hello, World! I am script 2" \ No newline at end of file diff --git a/pkg/extensionpolicysettings/testutils/testscripts/script3.sh b/pkg/extensionpolicysettings/testutils/testscripts/script3.sh new file mode 100644 index 0000000..2d27484 --- /dev/null +++ b/pkg/extensionpolicysettings/testutils/testscripts/script3.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# This is a simple shell script +echo "I am a banned script." \ No newline at end of file diff --git a/pkg/extensionpolicysettings/testutils/testscripts/script4.sh b/pkg/extensionpolicysettings/testutils/testscripts/script4.sh new file mode 100644 index 0000000..e69de29 diff --git a/pkg/extensionpolicysettings/testutils/testscripts/script5.sh b/pkg/extensionpolicysettings/testutils/testscripts/script5.sh new file mode 100644 index 0000000..b73c937 --- /dev/null +++ b/pkg/extensionpolicysettings/testutils/testscripts/script5.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# This is a simple shell script +echo "Hello, World! I am script5.sh. I will be hashed in SHA1" \ No newline at end of file