0
0
mirror of https://github.com/thegeeklab/wp-s3-action.git synced 2024-11-09 18:30:40 +00:00

refactor for concurrency

This commit is contained in:
Nathan LaFreniere 2015-12-19 16:15:04 -08:00
parent 07033aa1a0
commit 0d6aff17dc
2 changed files with 145 additions and 113 deletions

150
aws.go
View File

@ -7,7 +7,6 @@ import (
"mime" "mime"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
@ -35,28 +34,12 @@ func NewAWS(vargs PluginArgs) AWS {
return AWS{c, r, l, vargs} return AWS{c, r, l, vargs}
} }
func (a *AWS) visit(path string, info os.FileInfo, err error) error { func (a *AWS) Upload(local, remote string) error {
if err != nil { if local == "" {
return err
}
if path == "." {
return nil return nil
} }
if info.IsDir() { file, err := os.Open(local)
return nil
}
localPath := strings.TrimPrefix(path, a.vargs.Source)
if strings.HasPrefix(localPath, "/") {
localPath = localPath[1:]
}
remotePath := filepath.Join(a.vargs.Target, localPath)
a.local = append(a.local, localPath)
file, err := os.Open(path)
if err != nil { if err != nil {
return err return err
} }
@ -69,7 +52,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
} else if !a.vargs.Access.IsEmpty() { } else if !a.vargs.Access.IsEmpty() {
accessMap := a.vargs.Access.Map() accessMap := a.vargs.Access.Map()
for pattern := range accessMap { for pattern := range accessMap {
if match := glob.Glob(pattern, localPath); match == true { if match := glob.Glob(pattern, local); match == true {
access = accessMap[pattern] access = accessMap[pattern]
break break
} }
@ -80,7 +63,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
access = "private" access = "private"
} }
fileExt := filepath.Ext(localPath) fileExt := filepath.Ext(local)
var contentType string var contentType string
if a.vargs.ContentType.IsString() { if a.vargs.ContentType.IsString() {
contentType = a.vargs.ContentType.String() contentType = a.vargs.ContentType.String()
@ -98,7 +81,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
vmap := a.vargs.Metadata.Map() vmap := a.vargs.Metadata.Map()
if len(vmap) > 0 { if len(vmap) > 0 {
for pattern := range vmap { for pattern := range vmap {
if match := glob.Glob(pattern, localPath); match == true { if match := glob.Glob(pattern, local); match == true {
for k, v := range vmap[pattern] { for k, v := range vmap[pattern] {
metadata[k] = aws.String(v) metadata[k] = aws.String(v)
} }
@ -111,42 +94,34 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
contentType = mime.TypeByExtension(fileExt) contentType = mime.TypeByExtension(fileExt)
} }
exists := false head, err := a.client.HeadObject(&s3.HeadObjectInput{
for _, remoteFile := range a.remote { Bucket: aws.String(a.vargs.Bucket),
if remoteFile == localPath { Key: aws.String(remote),
exists = true })
break if err != nil {
} return err
} }
if exists { if head != nil {
hash := md5.New() hash := md5.New()
io.Copy(hash, file) io.Copy(hash, file)
sum := fmt.Sprintf("\"%x\"", hash.Sum(nil)) sum := fmt.Sprintf("\"%x\"", hash.Sum(nil))
head, err := a.client.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath),
})
if err != nil {
return err
}
if sum == *head.ETag { if sum == *head.ETag {
shouldCopy := false shouldCopy := false
if head.ContentType == nil && contentType != "" { if head.ContentType == nil && contentType != "" {
debug("Content-Type has changed from unset to %s\n", contentType) debug("Content-Type has changed from unset to %s", contentType)
shouldCopy = true shouldCopy = true
} }
if !shouldCopy && head.ContentType != nil && contentType != *head.ContentType { if !shouldCopy && head.ContentType != nil && contentType != *head.ContentType {
debug("Content-Type has changed from %s to %s\n", *head.ContentType, contentType) debug("Content-Type has changed from %s to %s", *head.ContentType, contentType)
shouldCopy = true shouldCopy = true
} }
if !shouldCopy && len(head.Metadata) != len(metadata) { if !shouldCopy && len(head.Metadata) != len(metadata) {
debug("Count of metadata values has changed for %s\n", localPath) debug("Count of metadata values has changed for %s", local)
shouldCopy = true shouldCopy = true
} }
@ -154,7 +129,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
for k, v := range metadata { for k, v := range metadata {
if hv, ok := head.Metadata[k]; ok { if hv, ok := head.Metadata[k]; ok {
if *v != *hv { if *v != *hv {
debug("Metadata values have changed for %s\n", localPath) debug("Metadata values have changed for %s", local)
shouldCopy = true shouldCopy = true
break break
} }
@ -165,7 +140,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
if !shouldCopy { if !shouldCopy {
grant, err := a.client.GetObjectAcl(&s3.GetObjectAclInput{ grant, err := a.client.GetObjectAcl(&s3.GetObjectAclInput{
Bucket: aws.String(a.vargs.Bucket), Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath), Key: aws.String(remote),
}) })
if err != nil { if err != nil {
return err return err
@ -190,21 +165,21 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
} }
if previousAccess != access { if previousAccess != access {
debug("Permissions for \"%s\" have changed from \"%s\" to \"%s\"\n", remotePath, previousAccess, access) debug("Permissions for \"%s\" have changed from \"%s\" to \"%s\"", remote, previousAccess, access)
shouldCopy = true shouldCopy = true
} }
} }
if !shouldCopy { if !shouldCopy {
debug("Skipping \"%s\" because hashes and metadata match\n", localPath) debug("Skipping \"%s\" because hashes and metadata match", local)
return nil return nil
} }
fmt.Printf("Updating metadata for \"%s\" Content-Type: \"%s\", ACL: \"%s\"\n", localPath, contentType, access) debug("Updating metadata for \"%s\" Content-Type: \"%s\", ACL: \"%s\"", local, contentType, access)
_, err = a.client.CopyObject(&s3.CopyObjectInput{ _, err = a.client.CopyObject(&s3.CopyObjectInput{
Bucket: aws.String(a.vargs.Bucket), Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath), Key: aws.String(remote),
CopySource: aws.String(fmt.Sprintf("%s/%s", a.vargs.Bucket, remotePath)), CopySource: aws.String(fmt.Sprintf("%s/%s", a.vargs.Bucket, remote)),
ACL: aws.String(access), ACL: aws.String(access),
ContentType: aws.String(contentType), ContentType: aws.String(contentType),
Metadata: metadata, Metadata: metadata,
@ -219,10 +194,10 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
} }
} }
fmt.Printf("Uploading \"%s\" with Content-Type \"%s\" and permissions \"%s\"\n", localPath, contentType, access) debug("Uploading \"%s\" with Content-Type \"%s\" and permissions \"%s\"", local, contentType, access)
_, err = a.client.PutObject(&s3.PutObjectInput{ _, err = a.client.PutObject(&s3.PutObjectInput{
Bucket: aws.String(a.vargs.Bucket), Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath), Key: aws.String(remote),
Body: file, Body: file,
ContentType: aws.String(contentType), ContentType: aws.String(contentType),
ACL: aws.String(access), ACL: aws.String(access),
@ -231,78 +206,55 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
return err return err
} }
func (a *AWS) AddRedirects(redirects map[string]string) error { func (a *AWS) Redirect(path, location string) error {
for path, location := range redirects { debug("Adding redirect from \"%s\" to \"%s\"\n", path, location)
fmt.Printf("Adding redirect from \"%s\" to \"%s\"", path, location) _, err := a.client.PutObject(&s3.PutObjectInput{
a.local = append(a.local, strings.TrimPrefix(path, "/")) Bucket: aws.String(a.vargs.Bucket),
_, err := a.client.PutObject(&s3.PutObjectInput{ Key: aws.String(path),
Bucket: aws.String(a.vargs.Bucket), ACL: aws.String("public-read"),
Key: aws.String(path), WebsiteRedirectLocation: aws.String(location),
ACL: aws.String("public-read"), })
WebsiteRedirectLocation: aws.String(location), return err
})
if err != nil {
return err
}
}
return nil
} }
func (a *AWS) List(path string) error { func (a *AWS) Delete(remote string) error {
debug("Removing remote file \"%s\"\n", remote)
_, err := a.client.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remote),
})
return err
}
func (a *AWS) List(path string) ([]string, error) {
remote := make([]string, 1, 1)
resp, err := a.client.ListObjects(&s3.ListObjectsInput{ resp, err := a.client.ListObjects(&s3.ListObjectsInput{
Bucket: aws.String(a.vargs.Bucket), Bucket: aws.String(a.vargs.Bucket),
Prefix: aws.String(path), Prefix: aws.String(path),
}) })
if err != nil { if err != nil {
return err return remote, err
} }
for _, item := range resp.Contents { for _, item := range resp.Contents {
a.remote = append(a.remote, *item.Key) remote = append(remote, *item.Key)
} }
for *resp.IsTruncated { for *resp.IsTruncated {
resp, err = a.client.ListObjects(&s3.ListObjectsInput{ resp, err = a.client.ListObjects(&s3.ListObjectsInput{
Bucket: aws.String(a.vargs.Bucket), Bucket: aws.String(a.vargs.Bucket),
Prefix: aws.String(path), Prefix: aws.String(path),
Marker: aws.String(a.remote[len(a.remote)-1]), Marker: aws.String(remote[len(remote)-1]),
}) })
if err != nil { if err != nil {
return err return remote, err
} }
for _, item := range resp.Contents { for _, item := range resp.Contents {
a.remote = append(a.remote, *item.Key) remote = append(remote, *item.Key)
} }
} }
return nil return remote, nil
}
func (a *AWS) Cleanup() error {
for _, remote := range a.remote {
found := false
for _, local := range a.local {
if local == remote {
found = true
break
}
}
if !found {
fmt.Printf("Removing remote file \"%s\"\n", remote)
_, err := a.client.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remote),
})
if err != nil {
return err
}
}
}
return nil
} }

108
main.go
View File

@ -10,6 +10,19 @@ import (
"github.com/drone/drone-go/plugin" "github.com/drone/drone-go/plugin"
) )
const maxConcurrent = 100
type job struct {
local string
remote string
action string
}
type result struct {
j job
err error
}
func main() { func main() {
vargs := PluginArgs{} vargs := PluginArgs{}
workspace := drone.Workspace{} workspace := drone.Workspace{}
@ -39,37 +52,104 @@ func main() {
} }
client := NewAWS(vargs) client := NewAWS(vargs)
err := client.List(vargs.Target) remote, err := client.List(vargs.Target)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
err = filepath.Walk(vargs.Source, client.visit) local := make([]string, 1, 1)
jobs := make([]job, 1, 1)
err = filepath.Walk(vargs.Source, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() {
return err
}
localPath := path
if vargs.Source != "." {
localPath = strings.TrimPrefix(path, vargs.Source)
if strings.HasPrefix(localPath, "/") {
localPath = localPath[1:]
}
}
local = append(local, localPath)
jobs = append(jobs, job{
local: filepath.Join(vargs.Source, localPath),
remote: filepath.Join(vargs.Target, localPath),
action: "upload",
})
return nil
})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
if len(vargs.Redirects) > 0 { for path, location := range vargs.Redirects {
err = client.AddRedirects(vargs.Redirects) path = strings.TrimPrefix(path, "/")
if err != nil { local = append(local, path)
fmt.Println(err) jobs = append(jobs, job{
local: path,
remote: location,
action: "redirect",
})
}
for _, r := range remote {
found := false
for _, l := range local {
if l == r {
found = true
break
}
}
if !found {
jobs = append(jobs, job{
local: "",
remote: r,
action: "delete",
})
}
}
jobChan := make(chan struct{}, maxConcurrent)
results := make(chan *result, len(jobs))
fmt.Printf("Synchronizing with bucket \"%s\"", vargs.Bucket)
for _, j := range jobs {
jobChan <- struct{}{}
go func(j job) {
if j.action == "upload" {
err = client.Upload(j.local, j.remote)
} else if j.action == "redirect" {
err = client.Redirect(j.local, j.remote)
} else if j.action == "delete" && vargs.Delete {
err = client.Delete(j.remote)
} else {
err = nil
}
results <- &result{j, err}
<-jobChan
}(j)
}
for _ = range jobs {
r := <-results
if r.err != nil {
fmt.Printf("ERROR: failed to %s %s to %s: %s\n", r.j.action, r.j.local, r.j.remote, r.err)
os.Exit(1) os.Exit(1)
} }
} }
if vargs.Delete { fmt.Println("done!")
err = client.Cleanup()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
} }
func debug(format string, args ...interface{}) { func debug(format string, args ...interface{}) {
if os.Getenv("DEBUG") != "" { if os.Getenv("DEBUG") != "" {
fmt.Printf(format, args...) fmt.Printf(format+"\n", args...)
} else {
fmt.Printf(".")
} }
} }