diff --git a/cmd/drone-ecr/main.go b/cmd/drone-ecr/main.go index 0df3fac..0fda28d 100644 --- a/cmd/drone-ecr/main.go +++ b/cmd/drone-ecr/main.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecr" ) @@ -27,6 +28,7 @@ func main() { create = parseBoolOrDefault(false, getenv("PLUGIN_CREATE_REPOSITORY", "ECR_CREATE_REPOSITORY")) lifecyclePolicy = getenv("PLUGIN_LIFECYCLE_POLICY") repositoryPolicy = getenv("PLUGIN_REPOSITORY_POLICY") + assumeRole = getenv("PLUGIN_ASSUME_ROLE") ) // set the region @@ -42,12 +44,12 @@ func main() { } sess, err := session.NewSession(&aws.Config{Region: ®ion}) - + if err != nil { log.Fatal(fmt.Sprintf("error creating aws session: %v", err)) } - svc := ecr.New(sess) + svc := getECRClient(sess, assumeRole) username, password, registry, err := getAuthInfo(svc) if err != nil { log.Fatal(fmt.Sprintf("error getting ECR auth: %v", err)) @@ -178,3 +180,11 @@ func getenv(key ...string) (s string) { } return } + +func getECRClient(sess *session.Session, role string) *ecr.ECR { + if role == "" { + return ecr.New(sess) + } + creds := stscreds.NewCredentials(sess, role) + return ecr.New(sess, &aws.Config{Credentials: creds}) +}