diff --git a/git/config.go b/git/config.go index 2cbd170..b4ebc6d 100644 --- a/git/config.go +++ b/git/config.go @@ -75,3 +75,17 @@ func (r *Repository) ConfigSSLVerify(skipVerify bool) *types.Cmd { Cmd: cmd, } } + +// ConfigSSHCommand sets custom SSH key. +func (r *Repository) ConfigSSHCommand(sshKey string) *types.Cmd { + args := []string{ + "config", + "--local", + "core.sshCommand", + "ssh -i " + sshKey, + } + + return &types.Cmd{ + Cmd: execabs.Command(gitBin, args...), + } +} diff --git a/git/config_test.go b/git/config_test.go index c47a5cf..12e63b0 100644 --- a/git/config_test.go +++ b/git/config_test.go @@ -122,3 +122,26 @@ func TestConfigSSLVerify(t *testing.T) { }) } } + +func TestConfigSSHCommand(t *testing.T) { + tests := []struct { + name string + repo Repository + sshKey string + want []string + }{ + { + name: "set SSH command with key", + repo: Repository{}, + sshKey: "/path/to/ssh/key", + want: []string{gitBin, "config", "--local", "core.sshCommand", "ssh -i /path/to/ssh/key"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := tt.repo.ConfigSSHCommand(tt.sshKey) + assert.Equal(t, tt.want, cmd.Cmd.Args) + }) + } +} diff --git a/plugin/impl.go b/plugin/impl.go index fcd7cbd..df2c72c 100644 --- a/plugin/impl.go +++ b/plugin/impl.go @@ -125,9 +125,7 @@ func (p *Plugin) Execute() error { // Write SSH key and netrc file. if p.Settings.SSHKey != "" { - if err := WriteSSHKey(homeDir, p.Settings.SSHKey); err != nil { - return err - } + batchCmd = append(batchCmd, p.Settings.Repo.ConfigSSHCommand(p.Settings.SSHKey)) } netrc := p.Settings.Netrc diff --git a/plugin/util.go b/plugin/util.go index 143beca..10e8ab5 100644 --- a/plugin/util.go +++ b/plugin/util.go @@ -10,38 +10,10 @@ const ( netrcFile = `machine %s login %s password %s -` - configFile = `Host * -StrictHostKeyChecking no -UserKnownHostsFile=/dev/null ` ) -const ( - strictFilePerm = 0o600 - strictDirPerm = 0o700 -) - -// WriteKey writes the SSH private key. -func WriteSSHKey(path, key string) error { - sshPath := filepath.Join(path, ".ssh") - confPath := filepath.Join(sshPath, "config") - keyPath := filepath.Join(sshPath, "id_rsa") - - if err := os.MkdirAll(sshPath, strictDirPerm); err != nil { - return fmt.Errorf("failed to create .ssh directory: %w", err) - } - - if err := os.WriteFile(confPath, []byte(configFile), strictFilePerm); err != nil { - return fmt.Errorf("failed to create .ssh/config file: %w", err) - } - - if err := os.WriteFile(keyPath, []byte(key), strictFilePerm); err != nil { - return fmt.Errorf("failed to create .ssh/id_rsa file: %w", err) - } - - return nil -} +const strictFilePerm = 0o600 // WriteNetrc writes the netrc file. func WriteNetrc(path, machine, login, password string) error { diff --git a/plugin/util_test.go b/plugin/util_test.go index bfa510a..59eac2d 100644 --- a/plugin/util_test.go +++ b/plugin/util_test.go @@ -9,49 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestWriteSSHKey(t *testing.T) { - tests := []struct { - name string - privateKey string - dir string - wantErr bool - }{ - { - name: "valid private key", - privateKey: "valid_private_key", - dir: t.TempDir(), - wantErr: false, - }, - { - name: "empty private key", - privateKey: "", - dir: t.TempDir(), - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := WriteSSHKey(tt.dir, tt.privateKey) - if tt.wantErr { - assert.Error(t, err) - - return - } - - assert.NoError(t, err) - - privateKeyPath := filepath.Join(tt.dir, ".ssh", "id_rsa") - _, err = os.Stat(privateKeyPath) - assert.NoError(t, err) - - configPath := filepath.Join(tt.dir, ".ssh", "config") - _, err = os.Stat(configPath) - assert.NoError(t, err) - }) - } -} - func TestWriteNetrc(t *testing.T) { tests := []struct { name string