0
0
mirror of https://github.com/thegeeklab/wp-s3-action.git synced 2024-11-21 14:50:39 +00:00

refactor for concurrency

This commit is contained in:
Nathan LaFreniere 2015-12-19 16:15:04 -08:00
parent 07033aa1a0
commit 0d6aff17dc
2 changed files with 145 additions and 113 deletions

150
aws.go
View File

@ -7,7 +7,6 @@ import (
"mime"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
@ -35,28 +34,12 @@ func NewAWS(vargs PluginArgs) AWS {
return AWS{c, r, l, vargs}
}
func (a *AWS) visit(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == "." {
func (a *AWS) Upload(local, remote string) error {
if local == "" {
return nil
}
if info.IsDir() {
return nil
}
localPath := strings.TrimPrefix(path, a.vargs.Source)
if strings.HasPrefix(localPath, "/") {
localPath = localPath[1:]
}
remotePath := filepath.Join(a.vargs.Target, localPath)
a.local = append(a.local, localPath)
file, err := os.Open(path)
file, err := os.Open(local)
if err != nil {
return err
}
@ -69,7 +52,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
} else if !a.vargs.Access.IsEmpty() {
accessMap := a.vargs.Access.Map()
for pattern := range accessMap {
if match := glob.Glob(pattern, localPath); match == true {
if match := glob.Glob(pattern, local); match == true {
access = accessMap[pattern]
break
}
@ -80,7 +63,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
access = "private"
}
fileExt := filepath.Ext(localPath)
fileExt := filepath.Ext(local)
var contentType string
if a.vargs.ContentType.IsString() {
contentType = a.vargs.ContentType.String()
@ -98,7 +81,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
vmap := a.vargs.Metadata.Map()
if len(vmap) > 0 {
for pattern := range vmap {
if match := glob.Glob(pattern, localPath); match == true {
if match := glob.Glob(pattern, local); match == true {
for k, v := range vmap[pattern] {
metadata[k] = aws.String(v)
}
@ -111,42 +94,34 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
contentType = mime.TypeByExtension(fileExt)
}
exists := false
for _, remoteFile := range a.remote {
if remoteFile == localPath {
exists = true
break
}
head, err := a.client.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remote),
})
if err != nil {
return err
}
if exists {
if head != nil {
hash := md5.New()
io.Copy(hash, file)
sum := fmt.Sprintf("\"%x\"", hash.Sum(nil))
head, err := a.client.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath),
})
if err != nil {
return err
}
if sum == *head.ETag {
shouldCopy := false
if head.ContentType == nil && contentType != "" {
debug("Content-Type has changed from unset to %s\n", contentType)
debug("Content-Type has changed from unset to %s", contentType)
shouldCopy = true
}
if !shouldCopy && head.ContentType != nil && contentType != *head.ContentType {
debug("Content-Type has changed from %s to %s\n", *head.ContentType, contentType)
debug("Content-Type has changed from %s to %s", *head.ContentType, contentType)
shouldCopy = true
}
if !shouldCopy && len(head.Metadata) != len(metadata) {
debug("Count of metadata values has changed for %s\n", localPath)
debug("Count of metadata values has changed for %s", local)
shouldCopy = true
}
@ -154,7 +129,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
for k, v := range metadata {
if hv, ok := head.Metadata[k]; ok {
if *v != *hv {
debug("Metadata values have changed for %s\n", localPath)
debug("Metadata values have changed for %s", local)
shouldCopy = true
break
}
@ -165,7 +140,7 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
if !shouldCopy {
grant, err := a.client.GetObjectAcl(&s3.GetObjectAclInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath),
Key: aws.String(remote),
})
if err != nil {
return err
@ -190,21 +165,21 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
}
if previousAccess != access {
debug("Permissions for \"%s\" have changed from \"%s\" to \"%s\"\n", remotePath, previousAccess, access)
debug("Permissions for \"%s\" have changed from \"%s\" to \"%s\"", remote, previousAccess, access)
shouldCopy = true
}
}
if !shouldCopy {
debug("Skipping \"%s\" because hashes and metadata match\n", localPath)
debug("Skipping \"%s\" because hashes and metadata match", local)
return nil
}
fmt.Printf("Updating metadata for \"%s\" Content-Type: \"%s\", ACL: \"%s\"\n", localPath, contentType, access)
debug("Updating metadata for \"%s\" Content-Type: \"%s\", ACL: \"%s\"", local, contentType, access)
_, err = a.client.CopyObject(&s3.CopyObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath),
CopySource: aws.String(fmt.Sprintf("%s/%s", a.vargs.Bucket, remotePath)),
Key: aws.String(remote),
CopySource: aws.String(fmt.Sprintf("%s/%s", a.vargs.Bucket, remote)),
ACL: aws.String(access),
ContentType: aws.String(contentType),
Metadata: metadata,
@ -219,10 +194,10 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
}
}
fmt.Printf("Uploading \"%s\" with Content-Type \"%s\" and permissions \"%s\"\n", localPath, contentType, access)
debug("Uploading \"%s\" with Content-Type \"%s\" and permissions \"%s\"", local, contentType, access)
_, err = a.client.PutObject(&s3.PutObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remotePath),
Key: aws.String(remote),
Body: file,
ContentType: aws.String(contentType),
ACL: aws.String(access),
@ -231,78 +206,55 @@ func (a *AWS) visit(path string, info os.FileInfo, err error) error {
return err
}
func (a *AWS) AddRedirects(redirects map[string]string) error {
for path, location := range redirects {
fmt.Printf("Adding redirect from \"%s\" to \"%s\"", path, location)
a.local = append(a.local, strings.TrimPrefix(path, "/"))
_, err := a.client.PutObject(&s3.PutObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(path),
ACL: aws.String("public-read"),
WebsiteRedirectLocation: aws.String(location),
})
if err != nil {
return err
}
}
return nil
func (a *AWS) Redirect(path, location string) error {
debug("Adding redirect from \"%s\" to \"%s\"\n", path, location)
_, err := a.client.PutObject(&s3.PutObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(path),
ACL: aws.String("public-read"),
WebsiteRedirectLocation: aws.String(location),
})
return err
}
func (a *AWS) List(path string) error {
func (a *AWS) Delete(remote string) error {
debug("Removing remote file \"%s\"\n", remote)
_, err := a.client.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remote),
})
return err
}
func (a *AWS) List(path string) ([]string, error) {
remote := make([]string, 1, 1)
resp, err := a.client.ListObjects(&s3.ListObjectsInput{
Bucket: aws.String(a.vargs.Bucket),
Prefix: aws.String(path),
})
if err != nil {
return err
return remote, err
}
for _, item := range resp.Contents {
a.remote = append(a.remote, *item.Key)
remote = append(remote, *item.Key)
}
for *resp.IsTruncated {
resp, err = a.client.ListObjects(&s3.ListObjectsInput{
Bucket: aws.String(a.vargs.Bucket),
Prefix: aws.String(path),
Marker: aws.String(a.remote[len(a.remote)-1]),
Marker: aws.String(remote[len(remote)-1]),
})
if err != nil {
return err
return remote, err
}
for _, item := range resp.Contents {
a.remote = append(a.remote, *item.Key)
remote = append(remote, *item.Key)
}
}
return nil
}
func (a *AWS) Cleanup() error {
for _, remote := range a.remote {
found := false
for _, local := range a.local {
if local == remote {
found = true
break
}
}
if !found {
fmt.Printf("Removing remote file \"%s\"\n", remote)
_, err := a.client.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(a.vargs.Bucket),
Key: aws.String(remote),
})
if err != nil {
return err
}
}
}
return nil
return remote, nil
}

108
main.go
View File

@ -10,6 +10,19 @@ import (
"github.com/drone/drone-go/plugin"
)
const maxConcurrent = 100
type job struct {
local string
remote string
action string
}
type result struct {
j job
err error
}
func main() {
vargs := PluginArgs{}
workspace := drone.Workspace{}
@ -39,37 +52,104 @@ func main() {
}
client := NewAWS(vargs)
err := client.List(vargs.Target)
remote, err := client.List(vargs.Target)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
err = filepath.Walk(vargs.Source, client.visit)
local := make([]string, 1, 1)
jobs := make([]job, 1, 1)
err = filepath.Walk(vargs.Source, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() {
return err
}
localPath := path
if vargs.Source != "." {
localPath = strings.TrimPrefix(path, vargs.Source)
if strings.HasPrefix(localPath, "/") {
localPath = localPath[1:]
}
}
local = append(local, localPath)
jobs = append(jobs, job{
local: filepath.Join(vargs.Source, localPath),
remote: filepath.Join(vargs.Target, localPath),
action: "upload",
})
return nil
})
if err != nil {
fmt.Println(err)
os.Exit(1)
}
if len(vargs.Redirects) > 0 {
err = client.AddRedirects(vargs.Redirects)
if err != nil {
fmt.Println(err)
for path, location := range vargs.Redirects {
path = strings.TrimPrefix(path, "/")
local = append(local, path)
jobs = append(jobs, job{
local: path,
remote: location,
action: "redirect",
})
}
for _, r := range remote {
found := false
for _, l := range local {
if l == r {
found = true
break
}
}
if !found {
jobs = append(jobs, job{
local: "",
remote: r,
action: "delete",
})
}
}
jobChan := make(chan struct{}, maxConcurrent)
results := make(chan *result, len(jobs))
fmt.Printf("Synchronizing with bucket \"%s\"", vargs.Bucket)
for _, j := range jobs {
jobChan <- struct{}{}
go func(j job) {
if j.action == "upload" {
err = client.Upload(j.local, j.remote)
} else if j.action == "redirect" {
err = client.Redirect(j.local, j.remote)
} else if j.action == "delete" && vargs.Delete {
err = client.Delete(j.remote)
} else {
err = nil
}
results <- &result{j, err}
<-jobChan
}(j)
}
for _ = range jobs {
r := <-results
if r.err != nil {
fmt.Printf("ERROR: failed to %s %s to %s: %s\n", r.j.action, r.j.local, r.j.remote, r.err)
os.Exit(1)
}
}
if vargs.Delete {
err = client.Cleanup()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
fmt.Println("done!")
}
func debug(format string, args ...interface{}) {
if os.Getenv("DEBUG") != "" {
fmt.Printf(format, args...)
fmt.Printf(format+"\n", args...)
} else {
fmt.Printf(".")
}
}