From 0752e233529239be63d83d671a44b1a720bd6e1c Mon Sep 17 00:00:00 2001 From: Robert Kaussow Date: Sun, 12 May 2024 11:01:28 +0200 Subject: [PATCH] refactor: use dedicated aws package (#113) --- .mockery.yaml | 6 + Makefile | 4 +- aws/api.go | 24 ++ aws/aws.go | 44 ++ aws/cloudfront.go | 40 ++ aws/cloudfront_test.go | 86 ++++ aws/mocks/mock_CloudfrontAPIClient.go | 112 +++++ aws/mocks/mock_S3APIClient.go | 555 +++++++++++++++++++++++++ aws/s3.go | 389 ++++++++++++++++++ aws/s3_test.go | 562 ++++++++++++++++++++++++++ go.mod | 26 +- go.sum | 47 ++- plugin/aws.go | 444 -------------------- plugin/impl.go | 58 ++- plugin/plugin.go | 1 - 15 files changed, 1933 insertions(+), 465 deletions(-) create mode 100644 .mockery.yaml create mode 100644 aws/api.go create mode 100644 aws/aws.go create mode 100644 aws/cloudfront.go create mode 100644 aws/cloudfront_test.go create mode 100644 aws/mocks/mock_CloudfrontAPIClient.go create mode 100644 aws/mocks/mock_S3APIClient.go create mode 100644 aws/s3.go create mode 100644 aws/s3_test.go delete mode 100644 plugin/aws.go diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 0000000..d040238 --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,6 @@ +--- +all: True +dir: "{{.PackageName}}/mocks" +outpkg: "mocks" +packages: + github.com/thegeeklab/wp-s3-action/aws: diff --git a/Makefile b/Makefile index e399805..69e7729 100644 --- a/Makefile +++ b/Makefile @@ -11,13 +11,14 @@ IMPORT := github.com/thegeeklab/$(EXECUTABLE) GO ?= go CWD ?= $(shell pwd) -PACKAGES ?= $(shell go list ./...) +PACKAGES ?= $(shell go list ./... | grep -Ev 'mocks') SOURCES ?= $(shell find . -name "*.go" -type f) GOFUMPT_PACKAGE ?= mvdan.cc/gofumpt@$(GOFUMPT_PACKAGE_VERSION) GOLANGCI_LINT_PACKAGE ?= github.com/golangci/golangci-lint/cmd/golangci-lint@$(GOLANGCI_LINT_PACKAGE_VERSION) XGO_PACKAGE ?= src.techknowlogick.com/xgo@latest GOTESTSUM_PACKAGE ?= gotest.tools/gotestsum@latest +MOCKERY_PACKAGE ?= github.com/vektra/mockery/v2@latest XGO_VERSION := go-1.22.x XGO_TARGETS ?= linux/amd64,linux/arm-6,linux/arm-7,linux/arm64 @@ -65,6 +66,7 @@ lint: golangci-lint .PHONY: generate generate: $(GO) generate $(PACKAGES) + $(GO) run $(MOCKERY_PACKAGE) .PHONY: test test: diff --git a/aws/api.go b/aws/api.go new file mode 100644 index 0000000..0ae0c6f --- /dev/null +++ b/aws/api.go @@ -0,0 +1,24 @@ +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/cloudfront" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +//nolint:lll +type S3APIClient interface { + HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) + PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) + GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + CopyObject(ctx context.Context, params *s3.CopyObjectInput, optFns ...func(*s3.Options)) (*s3.CopyObjectOutput, error) + GetObjectAcl(ctx context.Context, params *s3.GetObjectAclInput, optFns ...func(*s3.Options)) (*s3.GetObjectAclOutput, error) + DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) + ListObjects(ctx context.Context, params *s3.ListObjectsInput, optFns ...func(*s3.Options)) (*s3.ListObjectsOutput, error) +} + +//nolint:lll +type CloudfrontAPIClient interface { + CreateInvalidation(ctx context.Context, params *cloudfront.CreateInvalidationInput, optFns ...func(*cloudfront.Options)) (*cloudfront.CreateInvalidationOutput, error) +} diff --git a/aws/aws.go b/aws/aws.go new file mode 100644 index 0000000..0493c2a --- /dev/null +++ b/aws/aws.go @@ -0,0 +1,44 @@ +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/cloudfront" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +type Client struct { + S3 *S3 + Cloudfront *Cloudfront +} + +// NewClient creates a new S3 client with the provided configuration. +func NewClient(ctx context.Context, url, region, accessKey, secretKey string, pathStyle bool) (*Client, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("error while loading AWS config: %w", err) + } + + if url != "" { + cfg.BaseEndpoint = aws.String(url) + } + + // allowing to use the instance role or provide a key and secret + if accessKey != "" && secretKey != "" { + cfg.Credentials = credentials.NewStaticCredentialsProvider(accessKey, secretKey, "") + } + + c := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.UsePathStyle = pathStyle + }) + cf := cloudfront.NewFromConfig(cfg) + + return &Client{ + S3: &S3{client: c}, + Cloudfront: &Cloudfront{client: cf}, + }, nil +} diff --git a/aws/cloudfront.go b/aws/cloudfront.go new file mode 100644 index 0000000..6e27fc7 --- /dev/null +++ b/aws/cloudfront.go @@ -0,0 +1,40 @@ +package aws + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudfront" + "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/rs/zerolog/log" +) + +type Cloudfront struct { + client CloudfrontAPIClient + Distribution string +} + +type CloudfrontInvalidateOptions struct { + Path string +} + +// Invalidate invalidates the specified path in the CloudFront distribution. +func (c *Cloudfront) Invalidate(ctx context.Context, opt CloudfrontInvalidateOptions) error { + log.Debug().Msgf("invalidating '%s'", opt.Path) + + _, err := c.client.CreateInvalidation(ctx, &cloudfront.CreateInvalidationInput{ + DistributionId: aws.String(c.Distribution), + InvalidationBatch: &types.InvalidationBatch{ + CallerReference: aws.String(time.Now().Format(time.RFC3339Nano)), + Paths: &types.Paths{ + Quantity: aws.Int32(1), + Items: []string{ + opt.Path, + }, + }, + }, + }) + + return err +} diff --git a/aws/cloudfront_test.go b/aws/cloudfront_test.go new file mode 100644 index 0000000..5572065 --- /dev/null +++ b/aws/cloudfront_test.go @@ -0,0 +1,86 @@ +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/cloudfront" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/thegeeklab/wp-s3-action/aws/mocks" +) + +var ErrCreateInvalidation = errors.New("create invalidation failed") + +func TestCloudfront_Invalidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) (*Cloudfront, CloudfrontInvalidateOptions, func()) + wantErr bool + }{ + { + name: "invalidate path successfully", + setup: func(t *testing.T) (*Cloudfront, CloudfrontInvalidateOptions, func()) { + t.Helper() + + mockClient := mocks.NewMockCloudfrontAPIClient(t) + mockClient. + On("CreateInvalidation", mock.Anything, mock.Anything). + Return(&cloudfront.CreateInvalidationOutput{}, nil) + + return &Cloudfront{ + client: mockClient, + Distribution: "test-distribution", + }, CloudfrontInvalidateOptions{ + Path: "/path/to/invalidate", + }, func() { + mockClient.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "error when create invalidation fails", + setup: func(t *testing.T) (*Cloudfront, CloudfrontInvalidateOptions, func()) { + t.Helper() + + mockClient := mocks.NewMockCloudfrontAPIClient(t) + mockClient. + On("CreateInvalidation", mock.Anything, mock.Anything). + Return(&cloudfront.CreateInvalidationOutput{}, ErrCreateInvalidation) + + return &Cloudfront{ + client: mockClient, + Distribution: "test-distribution", + }, CloudfrontInvalidateOptions{ + Path: "/path/to/invalidate", + }, func() { + mockClient.AssertExpectations(t) + } + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cf, opt, teardown := tt.setup(t) + defer teardown() + + err := cf.Invalidate(context.Background(), opt) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + }) + } +} diff --git a/aws/mocks/mock_CloudfrontAPIClient.go b/aws/mocks/mock_CloudfrontAPIClient.go new file mode 100644 index 0000000..8e45d5d --- /dev/null +++ b/aws/mocks/mock_CloudfrontAPIClient.go @@ -0,0 +1,112 @@ +// Code generated by mockery v2.43.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + cloudfront "github.com/aws/aws-sdk-go-v2/service/cloudfront" + + mock "github.com/stretchr/testify/mock" +) + +// MockCloudfrontAPIClient is an autogenerated mock type for the CloudfrontAPIClient type +type MockCloudfrontAPIClient struct { + mock.Mock +} + +type MockCloudfrontAPIClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCloudfrontAPIClient) EXPECT() *MockCloudfrontAPIClient_Expecter { + return &MockCloudfrontAPIClient_Expecter{mock: &_m.Mock} +} + +// CreateInvalidation provides a mock function with given fields: ctx, params, optFns +func (_m *MockCloudfrontAPIClient) CreateInvalidation(ctx context.Context, params *cloudfront.CreateInvalidationInput, optFns ...func(*cloudfront.Options)) (*cloudfront.CreateInvalidationOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CreateInvalidation") + } + + var r0 *cloudfront.CreateInvalidationOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *cloudfront.CreateInvalidationInput, ...func(*cloudfront.Options)) (*cloudfront.CreateInvalidationOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *cloudfront.CreateInvalidationInput, ...func(*cloudfront.Options)) *cloudfront.CreateInvalidationOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*cloudfront.CreateInvalidationOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *cloudfront.CreateInvalidationInput, ...func(*cloudfront.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCloudfrontAPIClient_CreateInvalidation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateInvalidation' +type MockCloudfrontAPIClient_CreateInvalidation_Call struct { + *mock.Call +} + +// CreateInvalidation is a helper method to define mock.On call +// - ctx context.Context +// - params *cloudfront.CreateInvalidationInput +// - optFns ...func(*cloudfront.Options) +func (_e *MockCloudfrontAPIClient_Expecter) CreateInvalidation(ctx interface{}, params interface{}, optFns ...interface{}) *MockCloudfrontAPIClient_CreateInvalidation_Call { + return &MockCloudfrontAPIClient_CreateInvalidation_Call{Call: _e.mock.On("CreateInvalidation", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockCloudfrontAPIClient_CreateInvalidation_Call) Run(run func(ctx context.Context, params *cloudfront.CreateInvalidationInput, optFns ...func(*cloudfront.Options))) *MockCloudfrontAPIClient_CreateInvalidation_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*cloudfront.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*cloudfront.Options)) + } + } + run(args[0].(context.Context), args[1].(*cloudfront.CreateInvalidationInput), variadicArgs...) + }) + return _c +} + +func (_c *MockCloudfrontAPIClient_CreateInvalidation_Call) Return(_a0 *cloudfront.CreateInvalidationOutput, _a1 error) *MockCloudfrontAPIClient_CreateInvalidation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCloudfrontAPIClient_CreateInvalidation_Call) RunAndReturn(run func(context.Context, *cloudfront.CreateInvalidationInput, ...func(*cloudfront.Options)) (*cloudfront.CreateInvalidationOutput, error)) *MockCloudfrontAPIClient_CreateInvalidation_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCloudfrontAPIClient creates a new instance of MockCloudfrontAPIClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCloudfrontAPIClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCloudfrontAPIClient { + mock := &MockCloudfrontAPIClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/aws/mocks/mock_S3APIClient.go b/aws/mocks/mock_S3APIClient.go new file mode 100644 index 0000000..50fad2c --- /dev/null +++ b/aws/mocks/mock_S3APIClient.go @@ -0,0 +1,555 @@ +// Code generated by mockery v2.43.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + s3 "github.com/aws/aws-sdk-go-v2/service/s3" + mock "github.com/stretchr/testify/mock" +) + +// MockS3APIClient is an autogenerated mock type for the S3APIClient type +type MockS3APIClient struct { + mock.Mock +} + +type MockS3APIClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockS3APIClient) EXPECT() *MockS3APIClient_Expecter { + return &MockS3APIClient_Expecter{mock: &_m.Mock} +} + +// CopyObject provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) CopyObject(ctx context.Context, params *s3.CopyObjectInput, optFns ...func(*s3.Options)) (*s3.CopyObjectOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CopyObject") + } + + var r0 *s3.CopyObjectOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.CopyObjectInput, ...func(*s3.Options)) (*s3.CopyObjectOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.CopyObjectInput, ...func(*s3.Options)) *s3.CopyObjectOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.CopyObjectOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.CopyObjectInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_CopyObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CopyObject' +type MockS3APIClient_CopyObject_Call struct { + *mock.Call +} + +// CopyObject is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.CopyObjectInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) CopyObject(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_CopyObject_Call { + return &MockS3APIClient_CopyObject_Call{Call: _e.mock.On("CopyObject", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_CopyObject_Call) Run(run func(ctx context.Context, params *s3.CopyObjectInput, optFns ...func(*s3.Options))) *MockS3APIClient_CopyObject_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.CopyObjectInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_CopyObject_Call) Return(_a0 *s3.CopyObjectOutput, _a1 error) *MockS3APIClient_CopyObject_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_CopyObject_Call) RunAndReturn(run func(context.Context, *s3.CopyObjectInput, ...func(*s3.Options)) (*s3.CopyObjectOutput, error)) *MockS3APIClient_CopyObject_Call { + _c.Call.Return(run) + return _c +} + +// DeleteObject provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for DeleteObject") + } + + var r0 *s3.DeleteObjectOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.DeleteObjectInput, ...func(*s3.Options)) (*s3.DeleteObjectOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.DeleteObjectInput, ...func(*s3.Options)) *s3.DeleteObjectOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.DeleteObjectOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.DeleteObjectInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_DeleteObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteObject' +type MockS3APIClient_DeleteObject_Call struct { + *mock.Call +} + +// DeleteObject is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.DeleteObjectInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) DeleteObject(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_DeleteObject_Call { + return &MockS3APIClient_DeleteObject_Call{Call: _e.mock.On("DeleteObject", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_DeleteObject_Call) Run(run func(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options))) *MockS3APIClient_DeleteObject_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.DeleteObjectInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_DeleteObject_Call) Return(_a0 *s3.DeleteObjectOutput, _a1 error) *MockS3APIClient_DeleteObject_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_DeleteObject_Call) RunAndReturn(run func(context.Context, *s3.DeleteObjectInput, ...func(*s3.Options)) (*s3.DeleteObjectOutput, error)) *MockS3APIClient_DeleteObject_Call { + _c.Call.Return(run) + return _c +} + +// GetObject provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetObject") + } + + var r0 *s3.GetObjectOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) *s3.GetObjectOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.GetObjectOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_GetObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetObject' +type MockS3APIClient_GetObject_Call struct { + *mock.Call +} + +// GetObject is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.GetObjectInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) GetObject(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_GetObject_Call { + return &MockS3APIClient_GetObject_Call{Call: _e.mock.On("GetObject", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_GetObject_Call) Run(run func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options))) *MockS3APIClient_GetObject_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.GetObjectInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_GetObject_Call) Return(_a0 *s3.GetObjectOutput, _a1 error) *MockS3APIClient_GetObject_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_GetObject_Call) RunAndReturn(run func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)) *MockS3APIClient_GetObject_Call { + _c.Call.Return(run) + return _c +} + +// GetObjectAcl provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) GetObjectAcl(ctx context.Context, params *s3.GetObjectAclInput, optFns ...func(*s3.Options)) (*s3.GetObjectAclOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetObjectAcl") + } + + var r0 *s3.GetObjectAclOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectAclInput, ...func(*s3.Options)) (*s3.GetObjectAclOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectAclInput, ...func(*s3.Options)) *s3.GetObjectAclOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.GetObjectAclOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectAclInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_GetObjectAcl_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetObjectAcl' +type MockS3APIClient_GetObjectAcl_Call struct { + *mock.Call +} + +// GetObjectAcl is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.GetObjectAclInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) GetObjectAcl(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_GetObjectAcl_Call { + return &MockS3APIClient_GetObjectAcl_Call{Call: _e.mock.On("GetObjectAcl", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_GetObjectAcl_Call) Run(run func(ctx context.Context, params *s3.GetObjectAclInput, optFns ...func(*s3.Options))) *MockS3APIClient_GetObjectAcl_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.GetObjectAclInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_GetObjectAcl_Call) Return(_a0 *s3.GetObjectAclOutput, _a1 error) *MockS3APIClient_GetObjectAcl_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_GetObjectAcl_Call) RunAndReturn(run func(context.Context, *s3.GetObjectAclInput, ...func(*s3.Options)) (*s3.GetObjectAclOutput, error)) *MockS3APIClient_GetObjectAcl_Call { + _c.Call.Return(run) + return _c +} + +// HeadObject provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for HeadObject") + } + + var r0 *s3.HeadObjectOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) (*s3.HeadObjectOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) *s3.HeadObjectOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.HeadObjectOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_HeadObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HeadObject' +type MockS3APIClient_HeadObject_Call struct { + *mock.Call +} + +// HeadObject is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.HeadObjectInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) HeadObject(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_HeadObject_Call { + return &MockS3APIClient_HeadObject_Call{Call: _e.mock.On("HeadObject", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_HeadObject_Call) Run(run func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options))) *MockS3APIClient_HeadObject_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.HeadObjectInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_HeadObject_Call) Return(_a0 *s3.HeadObjectOutput, _a1 error) *MockS3APIClient_HeadObject_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_HeadObject_Call) RunAndReturn(run func(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) (*s3.HeadObjectOutput, error)) *MockS3APIClient_HeadObject_Call { + _c.Call.Return(run) + return _c +} + +// ListObjects provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) ListObjects(ctx context.Context, params *s3.ListObjectsInput, optFns ...func(*s3.Options)) (*s3.ListObjectsOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ListObjects") + } + + var r0 *s3.ListObjectsOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.ListObjectsInput, ...func(*s3.Options)) (*s3.ListObjectsOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.ListObjectsInput, ...func(*s3.Options)) *s3.ListObjectsOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.ListObjectsOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.ListObjectsInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_ListObjects_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListObjects' +type MockS3APIClient_ListObjects_Call struct { + *mock.Call +} + +// ListObjects is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.ListObjectsInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) ListObjects(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_ListObjects_Call { + return &MockS3APIClient_ListObjects_Call{Call: _e.mock.On("ListObjects", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_ListObjects_Call) Run(run func(ctx context.Context, params *s3.ListObjectsInput, optFns ...func(*s3.Options))) *MockS3APIClient_ListObjects_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.ListObjectsInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_ListObjects_Call) Return(_a0 *s3.ListObjectsOutput, _a1 error) *MockS3APIClient_ListObjects_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_ListObjects_Call) RunAndReturn(run func(context.Context, *s3.ListObjectsInput, ...func(*s3.Options)) (*s3.ListObjectsOutput, error)) *MockS3APIClient_ListObjects_Call { + _c.Call.Return(run) + return _c +} + +// PutObject provides a mock function with given fields: ctx, params, optFns +func (_m *MockS3APIClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + _va := make([]interface{}, len(optFns)) + for _i := range optFns { + _va[_i] = optFns[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, params) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for PutObject") + } + + var r0 *s3.PutObjectOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)); ok { + return rf(ctx, params, optFns...) + } + if rf, ok := ret.Get(0).(func(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) *s3.PutObjectOutput); ok { + r0 = rf(ctx, params, optFns...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*s3.PutObjectOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) error); ok { + r1 = rf(ctx, params, optFns...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockS3APIClient_PutObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PutObject' +type MockS3APIClient_PutObject_Call struct { + *mock.Call +} + +// PutObject is a helper method to define mock.On call +// - ctx context.Context +// - params *s3.PutObjectInput +// - optFns ...func(*s3.Options) +func (_e *MockS3APIClient_Expecter) PutObject(ctx interface{}, params interface{}, optFns ...interface{}) *MockS3APIClient_PutObject_Call { + return &MockS3APIClient_PutObject_Call{Call: _e.mock.On("PutObject", + append([]interface{}{ctx, params}, optFns...)...)} +} + +func (_c *MockS3APIClient_PutObject_Call) Run(run func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options))) *MockS3APIClient_PutObject_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*s3.Options), len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(func(*s3.Options)) + } + } + run(args[0].(context.Context), args[1].(*s3.PutObjectInput), variadicArgs...) + }) + return _c +} + +func (_c *MockS3APIClient_PutObject_Call) Return(_a0 *s3.PutObjectOutput, _a1 error) *MockS3APIClient_PutObject_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockS3APIClient_PutObject_Call) RunAndReturn(run func(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)) *MockS3APIClient_PutObject_Call { + _c.Call.Return(run) + return _c +} + +// NewMockS3APIClient creates a new instance of MockS3APIClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockS3APIClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockS3APIClient { + mock := &MockS3APIClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/aws/s3.go b/aws/s3.go new file mode 100644 index 0000000..0c752b9 --- /dev/null +++ b/aws/s3.go @@ -0,0 +1,389 @@ +package aws + +import ( + "context" + "crypto/md5" //nolint:gosec + "errors" + "fmt" + "io" + "mime" + "os" + "path/filepath" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/rs/zerolog/log" +) + +type S3 struct { + client S3APIClient + Bucket string + DryRun bool +} + +type S3UploadOptions struct { + LocalFilePath string + RemoteObjectKey string + ACL map[string]string + ContentType map[string]string + ContentEncoding map[string]string + CacheControl map[string]string + Metadata map[string]map[string]string +} + +type S3RedirectOptions struct { + Path string + Location string +} + +type S3DeleteOptions struct { + RemoteObjectKey string +} + +type S3ListOptions struct { + Path string +} + +// Upload uploads a file to an S3 bucket. It first checks if the file already exists in the bucket +// and compares the local file's content and metadata with the remote file. If the file has changed, +// it updates the remote file's metadata. If the file does not exist or has changed, +// it uploads the local file to the remote bucket. +func (u *S3) Upload(ctx context.Context, opt S3UploadOptions) error { + if opt.LocalFilePath == "" { + return nil + } + + file, err := os.Open(opt.LocalFilePath) + if err != nil { + return err + } + defer file.Close() + + acl := getACL(opt.LocalFilePath, opt.ACL) + contentType := getContentType(opt.LocalFilePath, opt.ContentType) + contentEncoding := getContentEncoding(opt.LocalFilePath, opt.ContentEncoding) + cacheControl := getCacheControl(opt.LocalFilePath, opt.CacheControl) + metadata := getMetadata(opt.LocalFilePath, opt.Metadata) + + head, err := u.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: &u.Bucket, + Key: &opt.RemoteObjectKey, + }) + if err != nil { + var noSuchKeyError *types.NoSuchKey + if !errors.As(err, &noSuchKeyError) { + return err + } + + log.Debug().Msgf( + "'%s' not found in bucket, uploading with content-type '%s' and permissions '%s'", + opt.LocalFilePath, + contentType, + acl, + ) + + if u.DryRun { + return nil + } + + _, err = u.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &u.Bucket, + Key: &opt.RemoteObjectKey, + Body: file, + ContentType: &contentType, + ACL: types.ObjectCannedACL(acl), + Metadata: metadata, + CacheControl: &cacheControl, + ContentEncoding: &contentEncoding, + }) + + return err + } + + //nolint:gosec + hash := md5.New() + _, _ = io.Copy(hash, file) + sum := fmt.Sprintf("'%x'", hash.Sum(nil)) + + if sum == *head.ETag { + shouldCopy, reason := u.shouldCopyObject( + ctx, head, opt.LocalFilePath, opt.RemoteObjectKey, contentType, acl, contentEncoding, cacheControl, metadata, + ) + if !shouldCopy { + log.Debug().Msgf("skipping '%s' because hashes and metadata match", opt.LocalFilePath) + + return nil + } + + log.Debug().Msgf("updating metadata for '%s' %s", opt.LocalFilePath, reason) + + if u.DryRun { + return nil + } + + _, err = u.client.CopyObject(ctx, &s3.CopyObjectInput{ + Bucket: &u.Bucket, + Key: &opt.RemoteObjectKey, + CopySource: aws.String(fmt.Sprintf("%s/%s", u.Bucket, opt.RemoteObjectKey)), + ACL: types.ObjectCannedACL(acl), + ContentType: &contentType, + Metadata: metadata, + MetadataDirective: types.MetadataDirectiveReplace, + CacheControl: &cacheControl, + ContentEncoding: &contentEncoding, + }) + + return err + } + + _, err = file.Seek(0, 0) + if err != nil { + return err + } + + log.Debug().Msgf("uploading '%s' with content-type '%s' and permissions '%s'", opt.LocalFilePath, contentType, acl) + + if u.DryRun { + return nil + } + + _, err = u.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &u.Bucket, + Key: &opt.RemoteObjectKey, + Body: file, + ContentType: &contentType, + ACL: types.ObjectCannedACL(acl), + Metadata: metadata, + CacheControl: &cacheControl, + ContentEncoding: &contentEncoding, + }) + + return err +} + +// shouldCopyObject determines whether an S3 object should be copied based on changes in content type, +// content encoding, cache control, and metadata. It compares the existing object's metadata with the +// provided metadata and returns a boolean indicating whether the object should be copied, +// along with a string describing the reason for the copy if applicable. +// +//nolint:gocognit +func (u *S3) shouldCopyObject( + ctx context.Context, head *s3.HeadObjectOutput, + local, remote, contentType, acl, contentEncoding, cacheControl string, + metadata map[string]string, +) (bool, string) { + var reason string + + if head.ContentType == nil && contentType != "" { + reason = fmt.Sprintf("content-type has changed from unset to %s", contentType) + + return true, reason + } + + if head.ContentType != nil && contentType != *head.ContentType { + reason = fmt.Sprintf("content-type has changed from %s to %s", *head.ContentType, contentType) + + return true, reason + } + + if head.ContentEncoding == nil && contentEncoding != "" { + reason = fmt.Sprintf("Content-Encoding has changed from unset to %s", contentEncoding) + + return true, reason + } + + if head.ContentEncoding != nil && contentEncoding != *head.ContentEncoding { + reason = fmt.Sprintf("Content-Encoding has changed from %s to %s", *head.ContentEncoding, contentEncoding) + + return true, reason + } + + if head.CacheControl == nil && cacheControl != "" { + reason = fmt.Sprintf("cache-control has changed from unset to %s", cacheControl) + + return true, reason + } + + if head.CacheControl != nil && cacheControl != *head.CacheControl { + reason = fmt.Sprintf("cache-control has changed from %s to %s", *head.CacheControl, cacheControl) + + return true, reason + } + + if len(head.Metadata) != len(metadata) { + reason = fmt.Sprintf("count of metadata values has changed for %s", local) + + return true, reason + } + + if len(metadata) > 0 { + for k, v := range metadata { + if hv, ok := head.Metadata[k]; ok { + if v != hv { + reason = fmt.Sprintf("metadata values have changed for %s", remote) + + return true, reason + } + } + } + } + + grant, err := u.client.GetObjectAcl(ctx, &s3.GetObjectAclInput{ + Bucket: &u.Bucket, + Key: &remote, + }) + if err != nil { + return false, "" + } + + previousACL := "private" + + for _, g := range grant.Grants { + grantee := g.Grantee + if grantee.URI != nil { + switch *grantee.URI { + case "http://acs.amazonaws.com/groups/global/AllUsers": + if g.Permission == "READ" { + previousACL = "public-read" + } else if g.Permission == "WRITE" { + previousACL = "public-read-write" + } + case "http://acs.amazonaws.com/groups/global/AuthenticatedUsers": + if g.Permission == "READ" { + previousACL = "authenticated-read" + } + } + } + } + + if previousACL != acl { + reason = fmt.Sprintf("permissions for '%s' have changed from '%s' to '%s'", remote, previousACL, acl) + + return true, reason + } + + return false, "" +} + +// getACL returns the ACL for the given file based on the provided patterns. +func getACL(file string, patterns map[string]string) string { + for pattern, acl := range patterns { + if match, _ := filepath.Match(pattern, file); match { + return acl + } + } + + return "private" +} + +// getContentType returns the content type for the given file based on the provided patterns. +func getContentType(file string, patterns map[string]string) string { + ext := filepath.Ext(file) + if contentType, ok := patterns[ext]; ok { + return contentType + } + + return mime.TypeByExtension(ext) +} + +// getContentEncoding returns the content encoding for the given file based on the provided patterns. +func getContentEncoding(file string, patterns map[string]string) string { + ext := filepath.Ext(file) + if contentEncoding, ok := patterns[ext]; ok { + return contentEncoding + } + + return "" +} + +// getCacheControl returns the cache control for the given file based on the provided patterns. +func getCacheControl(file string, patterns map[string]string) string { + for pattern, cacheControl := range patterns { + if match, _ := filepath.Match(pattern, file); match { + return cacheControl + } + } + + return "" +} + +// getMetadata returns the metadata for the given file based on the provided patterns. +func getMetadata(file string, patterns map[string]map[string]string) map[string]string { + metadata := make(map[string]string) + + for pattern, meta := range patterns { + if match, _ := filepath.Match(pattern, file); match { + for k, v := range meta { + metadata[k] = v + } + + break + } + } + + return metadata +} + +// Redirect adds a redirect from the specified path to the specified location in the S3 bucket. +func (u *S3) Redirect(ctx context.Context, opt S3RedirectOptions) error { + log.Debug().Msgf("adding redirect from '%s' to '%s'", opt.Path, opt.Location) + + if u.DryRun { + return nil + } + + _, err := u.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(u.Bucket), + Key: aws.String(opt.Path), + ACL: types.ObjectCannedACLPublicRead, + WebsiteRedirectLocation: aws.String(opt.Location), + }) + + return err +} + +// Delete removes the specified object from the S3 bucket. +func (u *S3) Delete(ctx context.Context, opt S3DeleteOptions) error { + log.Debug().Msgf("removing remote file '%s'", opt.RemoteObjectKey) + + if u.DryRun { + return nil + } + + _, err := u.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(u.Bucket), + Key: aws.String(opt.RemoteObjectKey), + }) + + return err +} + +// List retrieves a list of object keys in the S3 bucket under the specified path. +func (u *S3) List(ctx context.Context, opt S3ListOptions) ([]string, error) { + var remote []string + + input := &s3.ListObjectsInput{ + Bucket: aws.String(u.Bucket), + Prefix: aws.String(opt.Path), + } + + for { + resp, err := u.client.ListObjects(ctx, input) + if err != nil { + return remote, err + } + + for _, item := range resp.Contents { + remote = append(remote, *item.Key) + } + + if !*resp.IsTruncated { + break + } + + input.Marker = aws.String(remote[len(remote)-1]) + } + + return remote, nil +} diff --git a/aws/s3_test.go b/aws/s3_test.go new file mode 100644 index 0000000..02eb176 --- /dev/null +++ b/aws/s3_test.go @@ -0,0 +1,562 @@ +package aws + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/thegeeklab/wp-s3-action/aws/mocks" +) + +var ( + ErrPutObject = errors.New("put object failed") + ErrDeleteObject = errors.New("delete object failed") + ErrListObjects = errors.New("list objects failed") +) + +func createTempFile(t *testing.T, name string) string { + t.Helper() + + name = filepath.Join(t.TempDir(), name) + _ = os.WriteFile(name, []byte("hello"), 0o600) + + return name +} + +func TestS3_Upload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) (*S3, S3UploadOptions, func()) + wantErr bool + }{ + { + name: "skip upload when local is empty", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + return &S3{}, + S3UploadOptions{ + LocalFilePath: "", + }, func() {} + }, + wantErr: false, + }, + { + name: "error when local file does not exist", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + return &S3{}, + S3UploadOptions{ + LocalFilePath: "/path/to/non-existent/file", + }, func() {} + }, + wantErr: true, + }, + { + name: "upload new file with default acl and content type", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{}) + mockS3Client.On("PutObject", mock.Anything, mock.Anything).Return(&s3.PutObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "update metadata when content type changed", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{ + ETag: aws.String("'5d41402abc4b2a76b9719d911017c592'"), + ContentType: aws.String("application/octet-stream"), + }, nil) + mockS3Client.On("CopyObject", mock.Anything, mock.Anything).Return(&s3.CopyObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + ContentType: map[string]string{"*.txt": "text/plain"}, + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "update metadata when acl changed", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{ + ETag: aws.String("'5d41402abc4b2a76b9719d911017c592'"), + ContentType: aws.String("text/plain; charset=utf-8"), + }, nil) + mockS3Client.On("GetObjectAcl", mock.Anything, mock.Anything).Return(&s3.GetObjectAclOutput{ + Grants: []types.Grant{ + { + Grantee: &types.Grantee{ + URI: aws.String("http://acs.amazonaws.com/groups/global/AllUsers"), + }, + Permission: types.PermissionWrite, + }, + }, + }, nil) + mockS3Client.On("CopyObject", mock.Anything, mock.Anything).Return(&s3.CopyObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + ACL: map[string]string{"*.txt": "public-read"}, + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "update metadata when cache control changed", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{ + ETag: aws.String("'5d41402abc4b2a76b9719d911017c592'"), + ContentType: aws.String("text/plain; charset=utf-8"), + CacheControl: aws.String("max-age=0"), + }, nil) + mockS3Client.On("CopyObject", mock.Anything, mock.Anything).Return(&s3.CopyObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + CacheControl: map[string]string{"*.txt": "max-age=3600"}, + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "update metadata when content encoding changed", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{ + ETag: aws.String("'5d41402abc4b2a76b9719d911017c592'"), + ContentType: aws.String("text/plain; charset=utf-8"), + ContentEncoding: aws.String("identity"), + }, nil) + mockS3Client.On("CopyObject", mock.Anything, mock.Anything).Return(&s3.CopyObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + ContentEncoding: map[string]string{"*.txt": "gzip"}, + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "update metadata when metadata changed", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{ + ETag: aws.String("'5d41402abc4b2a76b9719d911017c592'"), + ContentType: aws.String("text/plain; charset=utf-8"), + Metadata: map[string]string{"key": "old-value"}, + }, nil) + mockS3Client.On("CopyObject", mock.Anything, mock.Anything).Return(&s3.CopyObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file.txt"), + RemoteObjectKey: "remote/path/file.txt", + Metadata: map[string]map[string]string{"*.txt": {"key": "value"}}, + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "upload new file when dry run is true", + setup: func(t *testing.T) (*S3, S3UploadOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("HeadObject", mock.Anything, mock.Anything).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{}) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + DryRun: true, + }, S3UploadOptions{ + LocalFilePath: createTempFile(t, "file1.txt"), + RemoteObjectKey: "remote/path/file1.txt", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s3, opt, teardown := tt.setup(t) + defer teardown() + + err := s3.Upload(context.Background(), opt) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + }) + } +} + +func TestS3_Redirect(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) (*S3, S3RedirectOptions, func()) + wantErr bool + }{ + { + name: "redirect with valid options", + setup: func(t *testing.T) (*S3, S3RedirectOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("PutObject", mock.Anything, mock.Anything).Return(&s3.PutObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3RedirectOptions{ + Path: "redirect/path", + Location: "https://example.com", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "skip redirect when dry run is true", + setup: func(t *testing.T) (*S3, S3RedirectOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + DryRun: true, + }, S3RedirectOptions{ + Path: "redirect/path", + Location: "https://example.com", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "error when put object fails", + setup: func(t *testing.T) (*S3, S3RedirectOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client. + On("PutObject", mock.Anything, mock.Anything). + Return(&s3.PutObjectOutput{}, ErrPutObject) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3RedirectOptions{ + Path: "redirect/path", + Location: "https://example.com", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s3, opt, teardown := tt.setup(t) + defer teardown() + + err := s3.Redirect(context.Background(), opt) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + }) + } +} + +func TestS3_Delete(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) (*S3, S3DeleteOptions, func()) + wantErr bool + }{ + { + name: "delete existing object", + setup: func(t *testing.T) (*S3, S3DeleteOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("DeleteObject", mock.Anything, mock.Anything).Return(&s3.DeleteObjectOutput{}, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3DeleteOptions{ + RemoteObjectKey: "path/to/file.txt", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "skip delete when dry run is true", + setup: func(t *testing.T) (*S3, S3DeleteOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + DryRun: true, + }, S3DeleteOptions{ + RemoteObjectKey: "path/to/file.txt", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + }, + { + name: "error when delete object fails", + setup: func(t *testing.T) (*S3, S3DeleteOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client. + On("DeleteObject", mock.Anything, mock.Anything). + Return(&s3.DeleteObjectOutput{}, ErrDeleteObject) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3DeleteOptions{ + RemoteObjectKey: "path/to/file.txt", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s3, opt, teardown := tt.setup(t) + defer teardown() + + err := s3.Delete(context.Background(), opt) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + }) + } +} + +func TestS3_List(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) (*S3, S3ListOptions, func()) + wantErr bool + want []string + }{ + { + name: "list objects in prefix", + setup: func(t *testing.T) (*S3, S3ListOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("ListObjects", mock.Anything, mock.Anything).Return(&s3.ListObjectsOutput{ + Contents: []types.Object{ + {Key: aws.String("prefix/file1.txt")}, + {Key: aws.String("prefix/file2.txt")}, + }, + IsTruncated: aws.Bool(false), + }, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3ListOptions{ + Path: "prefix/", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + want: []string{"prefix/file1.txt", "prefix/file2.txt"}, + }, + { + name: "list objects with pagination", + setup: func(t *testing.T) (*S3, S3ListOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client.On("ListObjects", mock.Anything, mock.MatchedBy(func(input *s3.ListObjectsInput) bool { + return input.Marker == nil + })).Return(&s3.ListObjectsOutput{ + Contents: []types.Object{ + {Key: aws.String("prefix/file1.txt")}, + {Key: aws.String("prefix/file2.txt")}, + }, + IsTruncated: aws.Bool(true), + }, nil) + mockS3Client.On("ListObjects", mock.Anything, mock.MatchedBy(func(input *s3.ListObjectsInput) bool { + return *input.Marker == "prefix/file2.txt" + })).Return(&s3.ListObjectsOutput{ + Contents: []types.Object{ + {Key: aws.String("prefix/file3.txt")}, + }, + IsTruncated: aws.Bool(false), + }, nil) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3ListOptions{ + Path: "prefix/", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: false, + want: []string{"prefix/file1.txt", "prefix/file2.txt", "prefix/file3.txt"}, + }, + { + name: "error when list objects fails", + setup: func(t *testing.T) (*S3, S3ListOptions, func()) { + t.Helper() + + mockS3Client := mocks.NewMockS3APIClient(t) + mockS3Client. + On("ListObjects", mock.Anything, mock.Anything). + Return(&s3.ListObjectsOutput{}, ErrListObjects) + + return &S3{ + client: mockS3Client, + Bucket: "test-bucket", + }, S3ListOptions{ + Path: "prefix/", + }, func() { + mockS3Client.AssertExpectations(t) + } + }, + wantErr: true, + want: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s3, opt, teardown := tt.setup(t) + defer teardown() + + got, err := s3.List(context.Background(), opt) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + assert.ElementsMatch(t, tt.want, got) + }) + } +} diff --git a/go.mod b/go.mod index a2361f5..ce0e426 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,13 @@ module github.com/thegeeklab/wp-s3-action go 1.22 require ( - github.com/aws/aws-sdk-go v1.53.0 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/config v1.27.13 + github.com/aws/aws-sdk-go-v2/credentials v1.17.13 + github.com/aws/aws-sdk-go-v2/service/cloudfront v1.36.1 + github.com/aws/aws-sdk-go-v2/service/s3 v1.53.2 github.com/rs/zerolog v1.32.0 - github.com/ryanuber/go-glob v1.0.0 + github.com/stretchr/testify v1.9.0 github.com/thegeeklab/wp-plugin-go/v2 v2.3.1 github.com/urfave/cli/v2 v2.27.2 ) @@ -14,7 +18,22 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/Masterminds/sprig/v3 v3.2.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.1.1 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.11 // indirect @@ -24,11 +43,14 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect github.com/mitchellh/copystructure v1.0.0 // indirect github.com/mitchellh/reflectwalk v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/cast v1.3.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect golang.org/x/crypto v0.23.0 // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/sys v0.20.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8611d9c..d56b9cf 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,44 @@ github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= -github.com/aws/aws-sdk-go v1.53.0 h1:MMo1x1ggPPxDfHMXJnQudTbGXYlD4UigUAud1DJxPVo= -github.com/aws/aws-sdk-go v1.53.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/config v1.27.13 h1:WbKW8hOzrWoOA/+35S5okqO/2Ap8hkkFUzoW8Hzq24A= +github.com/aws/aws-sdk-go-v2/config v1.27.13/go.mod h1:XLiyiTMnguytjRER7u5RIkhIqS8Nyz41SwAWb4xEjxs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.13 h1:XDCJDzk/u5cN7Aple7D/MiAhx1Rjo/0nueJ0La8mRuE= +github.com/aws/aws-sdk-go-v2/credentials v1.17.13/go.mod h1:FMNcjQrmuBYvOTZDtOLCIu0esmxjF7RuA/89iSXWzQI= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 h1:81KE7vaZzrl7yHBYHVEzYB8sypz11NMOZ40YlWvPxsU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5/go.mod h1:LIt2rg7Mcgn09Ygbdh/RdIm0rQ+3BNkbP1gyVMFtRK0= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.36.1 h1://GRw/PrpnUyWBJh6KvUvR9AgkDBhclzaj3HKGxRoCw= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.36.1/go.mod h1:Pphkts8iBnexoEpcMti5fUvN3/yoGRLtl2heOeppF70= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 h1:ZMeFZ5yk+Ek+jNr1+uwCd2tG89t6oTS5yVWpa6yy2es= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7/go.mod h1:mxV05U+4JiHqIpGqqYXOHLPKUC6bDXC44bsUhNjOEwY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 h1:f9RyWNtS8oH7cZlbn+/JNPpjUk5+5fLd5lM9M0i49Ys= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5/go.mod h1:h5CoMZV2VF297/VLhRhO1WF+XYWOzXo+4HsObA4HjBQ= +github.com/aws/aws-sdk-go-v2/service/s3 v1.53.2 h1:rq2hglTQM3yHZvOPVMtNvLS5x6hijx7JvRDgKiTNDGQ= +github.com/aws/aws-sdk-go-v2/service/s3 v1.53.2/go.mod h1:qmdkIIAC+GCLASF7R2whgNrJADz0QZPX+Seiw/i4S3o= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 h1:o5cTaeunSpfXiLTIBx5xo2enQmiChtu1IBbzXnfU9Hs= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0 h1:Qe0r0lVURDDeBQJ4yP+BOrJkvkiCo/3FH/t+wY11dmw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 h1:et3Ta53gotFR4ERLXXHIHl/Uuk1qYpP5uU7cvNql8ns= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -43,13 +79,13 @@ github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= -github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= @@ -93,12 +129,11 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/plugin/aws.go b/plugin/aws.go deleted file mode 100644 index 2681621..0000000 --- a/plugin/aws.go +++ /dev/null @@ -1,444 +0,0 @@ -package plugin - -import ( - //nolint:gosec - "crypto/md5" - "errors" - "fmt" - "io" - "mime" - "os" - "path/filepath" - "strings" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudfront" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/rs/zerolog/log" - "github.com/ryanuber/go-glob" -) - -type AWS struct { - client *s3.S3 - cfClient *cloudfront.CloudFront - remote []string - local []string - plugin *Plugin -} - -func NewAWS(plugin *Plugin) AWS { - sessCfg := &aws.Config{ - S3ForcePathStyle: aws.Bool(plugin.Settings.PathStyle), - Region: aws.String(plugin.Settings.Region), - } - - if plugin.Settings.Endpoint != "" { - sessCfg.Endpoint = &plugin.Settings.Endpoint - sessCfg.DisableSSL = aws.Bool(strings.HasPrefix(plugin.Settings.Endpoint, "http://")) - } - - // allowing to use the instance role or provide a key and secret - if plugin.Settings.AccessKey != "" && plugin.Settings.SecretKey != "" { - sessCfg.Credentials = credentials.NewStaticCredentials(plugin.Settings.AccessKey, plugin.Settings.SecretKey, "") - } - - sess, _ := session.NewSession(sessCfg) - - c := s3.New(sess) - cf := cloudfront.New(sess) - r := make([]string, 1) - l := make([]string, 1) - - return AWS{c, cf, r, l, plugin} -} - -//nolint:gocognit,gocyclo,maintidx -func (a *AWS) Upload(local, remote string) error { - plugin := a.plugin - - if local == "" { - return nil - } - - file, err := os.Open(local) - if err != nil { - return err - } - - defer file.Close() - - var acl string - - for pattern := range plugin.Settings.ACL { - if match := glob.Glob(pattern, local); match { - acl = plugin.Settings.ACL[pattern] - - break - } - } - - if acl == "" { - acl = "private" - } - - fileExt := filepath.Ext(local) - - var contentType string - - for patternExt := range plugin.Settings.ContentType { - if patternExt == fileExt { - contentType = plugin.Settings.ContentType[patternExt] - - break - } - } - - if contentType == "" { - contentType = mime.TypeByExtension(fileExt) - } - - var contentEncoding string - - for patternExt := range plugin.Settings.ContentEncoding { - if patternExt == fileExt { - contentEncoding = plugin.Settings.ContentEncoding[patternExt] - - break - } - } - - var cacheControl string - - for pattern := range plugin.Settings.CacheControl { - if match := glob.Glob(pattern, local); match { - cacheControl = plugin.Settings.CacheControl[pattern] - - break - } - } - - metadata := map[string]*string{} - - for pattern := range plugin.Settings.Metadata { - if match := glob.Glob(pattern, local); match { - for k, v := range plugin.Settings.Metadata[pattern] { - metadata[k] = aws.String(v) - } - - break - } - } - - var AWSErr awserr.Error - - head, err := a.client.HeadObject(&s3.HeadObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - }) - if err != nil && errors.As(err, &AWSErr) { - //nolint:errorlint,forcetypeassert - if err.(awserr.Error).Code() == "404" { - return err - } - - log.Debug().Msgf( - "'%s' not found in bucket, uploading with content-type '%s' and permissions '%s'", - local, - contentType, - acl, - ) - - putObject := &s3.PutObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - Body: file, - ContentType: aws.String(contentType), - ACL: aws.String(acl), - Metadata: metadata, - } - - if len(cacheControl) > 0 { - putObject.CacheControl = aws.String(cacheControl) - } - - if len(contentEncoding) > 0 { - putObject.ContentEncoding = aws.String(contentEncoding) - } - - // skip upload during dry run - if a.plugin.Settings.DryRun { - return nil - } - - _, err = a.client.PutObject(putObject) - - return err - } - - //nolint:gosec - hash := md5.New() - _, _ = io.Copy(hash, file) - sum := fmt.Sprintf("'%x'", hash.Sum(nil)) - - //nolint:nestif - if sum == *head.ETag { - shouldCopy := false - - if head.ContentType == nil && contentType != "" { - log.Debug().Msgf("content-type has changed from unset to %s", contentType) - - shouldCopy = true - } - - if !shouldCopy && head.ContentType != nil && contentType != *head.ContentType { - log.Debug().Msgf("content-type has changed from %s to %s", *head.ContentType, contentType) - - shouldCopy = true - } - - if !shouldCopy && head.ContentEncoding == nil && contentEncoding != "" { - log.Debug().Msgf("Content-Encoding has changed from unset to %s", contentEncoding) - - shouldCopy = true - } - - if !shouldCopy && head.ContentEncoding != nil && contentEncoding != *head.ContentEncoding { - log.Debug().Msgf("Content-Encoding has changed from %s to %s", *head.ContentEncoding, contentEncoding) - - shouldCopy = true - } - - if !shouldCopy && head.CacheControl == nil && cacheControl != "" { - log.Debug().Msgf("cache-control has changed from unset to %s", cacheControl) - - shouldCopy = true - } - - if !shouldCopy && head.CacheControl != nil && cacheControl != *head.CacheControl { - log.Debug().Msgf("cache-control has changed from %s to %s", *head.CacheControl, cacheControl) - - shouldCopy = true - } - - if !shouldCopy && len(head.Metadata) != len(metadata) { - log.Debug().Msgf("count of metadata values has changed for %s", local) - - shouldCopy = true - } - - if !shouldCopy && len(metadata) > 0 { - for k, v := range metadata { - if hv, ok := head.Metadata[k]; ok { - if *v != *hv { - log.Debug().Msgf("metadata values have changed for %s", local) - - shouldCopy = true - - break - } - } - } - } - - if !shouldCopy { - grant, err := a.client.GetObjectAcl(&s3.GetObjectAclInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - }) - if err != nil { - return err - } - - previousACL := "private" - - for _, grant := range grant.Grants { - grantee := *grant.Grantee - if grantee.URI != nil { - if *grantee.URI == "http://acs.amazonaws.com/groups/global/AllUsers" { - if *grant.Permission == "READ" { - previousACL = "public-read" - } else if *grant.Permission == "WRITE" { - previousACL = "public-read-write" - } - } - - if *grantee.URI == "http://acs.amazonaws.com/groups/global/AuthenticatedUsers" { - if *grant.Permission == "READ" { - previousACL = "authenticated-read" - } - } - } - } - - if previousACL != acl { - log.Debug().Msgf("permissions for '%s' have changed from '%s' to '%s'", remote, previousACL, acl) - - shouldCopy = true - } - } - - if !shouldCopy { - log.Debug().Msgf("skipping '%s' because hashes and metadata match", local) - - return nil - } - - log.Debug().Msgf("updating metadata for '%s' content-type: '%s', ACL: '%s'", local, contentType, acl) - - copyObject := &s3.CopyObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - CopySource: aws.String(fmt.Sprintf("%s/%s", plugin.Settings.Bucket, remote)), - ACL: aws.String(acl), - ContentType: aws.String(contentType), - Metadata: metadata, - MetadataDirective: aws.String("REPLACE"), - } - - if len(cacheControl) > 0 { - copyObject.CacheControl = aws.String(cacheControl) - } - - if len(contentEncoding) > 0 { - copyObject.ContentEncoding = aws.String(contentEncoding) - } - - // skip update if dry run - if a.plugin.Settings.DryRun { - return nil - } - - _, err = a.client.CopyObject(copyObject) - - return err - } - - _, err = file.Seek(0, 0) - if err != nil { - return err - } - - log.Debug().Msgf("uploading '%s' with content-type '%s' and permissions '%s'", local, contentType, acl) - - putObject := &s3.PutObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - Body: file, - ContentType: aws.String(contentType), - ACL: aws.String(acl), - Metadata: metadata, - } - - if len(cacheControl) > 0 { - putObject.CacheControl = aws.String(cacheControl) - } - - if len(contentEncoding) > 0 { - putObject.ContentEncoding = aws.String(contentEncoding) - } - - // skip upload if dry run - if a.plugin.Settings.DryRun { - return nil - } - - _, err = a.client.PutObject(putObject) - - return err -} - -func (a *AWS) Redirect(path, location string) error { - plugin := a.plugin - - log.Debug().Msgf("adding redirect from '%s' to '%s'", path, location) - - if a.plugin.Settings.DryRun { - return nil - } - - _, err := a.client.PutObject(&s3.PutObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(path), - ACL: aws.String("public-read"), - WebsiteRedirectLocation: aws.String(location), - }) - - return err -} - -func (a *AWS) Delete(remote string) error { - plugin := a.plugin - - log.Debug().Msgf("removing remote file '%s'", remote) - - if a.plugin.Settings.DryRun { - return nil - } - - _, err := a.client.DeleteObject(&s3.DeleteObjectInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Key: aws.String(remote), - }) - - return err -} - -func (a *AWS) List(path string) ([]string, error) { - plugin := a.plugin - - remote := make([]string, 0) - - resp, err := a.client.ListObjects(&s3.ListObjectsInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Prefix: aws.String(path), - }) - if err != nil { - return remote, err - } - - for _, item := range resp.Contents { - remote = append(remote, *item.Key) - } - - for *resp.IsTruncated { - resp, err = a.client.ListObjects(&s3.ListObjectsInput{ - Bucket: aws.String(plugin.Settings.Bucket), - Prefix: aws.String(path), - Marker: aws.String(remote[len(remote)-1]), - }) - if err != nil { - return remote, err - } - - for _, item := range resp.Contents { - remote = append(remote, *item.Key) - } - } - - return remote, nil -} - -func (a *AWS) Invalidate(invalidatePath string) error { - p := a.plugin - - log.Debug().Msgf("invalidating '%s'", invalidatePath) - - _, err := a.cfClient.CreateInvalidation(&cloudfront.CreateInvalidationInput{ - DistributionId: aws.String(p.Settings.CloudFrontDistribution), - InvalidationBatch: &cloudfront.InvalidationBatch{ - CallerReference: aws.String(time.Now().Format(time.RFC3339Nano)), - Paths: &cloudfront.Paths{ - Quantity: aws.Int64(1), - Items: []*string{ - aws.String(invalidatePath), - }, - }, - }, - }) - - return err -} diff --git a/plugin/impl.go b/plugin/impl.go index dc42f47..2620e07 100644 --- a/plugin/impl.go +++ b/plugin/impl.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/rs/zerolog/log" + "github.com/thegeeklab/wp-s3-action/aws" ) var ErrTypeAssertionFailed = errors.New("type assertion failed") @@ -44,9 +45,25 @@ func (p *Plugin) Validate() error { // Execute provides the implementation of the plugin. func (p *Plugin) Execute() error { p.Settings.Jobs = make([]Job, 1) - p.Settings.Client = NewAWS(p) - if err := p.createSyncJobs(); err != nil { + client, err := aws.NewClient( + p.Network.Context, + p.Settings.Endpoint, + p.Settings.Region, + p.Settings.AccessKey, + p.Settings.SecretKey, + p.Settings.PathStyle, + ) + if err != nil { + return fmt.Errorf("error while creating AWS client: %w", err) + } + + client.S3.Bucket = p.Settings.Bucket + client.S3.DryRun = p.Settings.DryRun + + client.Cloudfront.Distribution = p.Settings.CloudFrontDistribution + + if err := p.createSyncJobs(p.Network.Context, client); err != nil { return fmt.Errorf("error while creating sync job: %w", err) } @@ -58,15 +75,15 @@ func (p *Plugin) Execute() error { }) } - if err := p.runJobs(); err != nil { + if err := p.runJobs(p.Network.Context, client); err != nil { return fmt.Errorf("error while creating sync job: %w", err) } return nil } -func (p *Plugin) createSyncJobs() error { - remote, err := p.Settings.Client.List(p.Settings.Target) +func (p *Plugin) createSyncJobs(ctx context.Context, client *aws.Client) error { + remote, err := client.S3.List(ctx, aws.S3ListOptions{Path: p.Settings.Target}) if err != nil { return err } @@ -134,8 +151,7 @@ func (p *Plugin) createSyncJobs() error { return nil } -func (p *Plugin) runJobs() error { - client := p.Settings.Client +func (p *Plugin) runJobs(ctx context.Context, client *aws.Client) error { jobChan := make(chan struct{}, p.Settings.MaxConcurrency) results := make(chan *Result, len(p.Settings.Jobs)) @@ -151,11 +167,27 @@ func (p *Plugin) runJobs() error { switch job.action { case "upload": - err = client.Upload(job.local, job.remote) + opt := aws.S3UploadOptions{ + LocalFilePath: job.local, + RemoteObjectKey: job.remote, + ACL: p.Settings.ACL, + ContentType: p.Settings.ContentType, + ContentEncoding: p.Settings.ContentEncoding, + CacheControl: p.Settings.CacheControl, + Metadata: p.Settings.Metadata, + } + err = client.S3.Upload(ctx, opt) case "redirect": - err = client.Redirect(job.local, job.remote) + opt := aws.S3RedirectOptions{ + Path: job.local, + Location: job.remote, + } + err = client.S3.Redirect(ctx, opt) case "delete": - err = client.Delete(job.remote) + opt := aws.S3DeleteOptions{ + RemoteObjectKey: job.remote, + } + err = client.S3.Delete(ctx, opt) case "invalidateCloudFront": invalidateJob = &job default: @@ -175,7 +207,11 @@ func (p *Plugin) runJobs() error { } if invalidateJob != nil { - err := client.Invalidate(invalidateJob.remote) + opt := aws.CloudfrontInvalidateOptions{ + Path: invalidateJob.remote, + } + + err := client.Cloudfront.Invalidate(ctx, opt) if err != nil { return fmt.Errorf("failed to %s %s to %s: %w", invalidateJob.action, invalidateJob.local, invalidateJob.remote, err) } diff --git a/plugin/plugin.go b/plugin/plugin.go index cb60b41..f0caec1 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -35,7 +35,6 @@ type Settings struct { CloudFrontDistribution string DryRun bool PathStyle bool - Client AWS Jobs []Job MaxConcurrency int }