summaryrefslogtreecommitdiff
path: root/remote_test.go
diff options
context:
space:
mode:
authorlhchavez <[email protected]>2020-12-06 11:55:04 -0800
committerGitHub <[email protected]>2020-12-06 11:55:04 -0800
commitabf02bc7d79dfb7b0bbcd404ebecb202cff2a18e (patch)
treeec93caf1ed9238b91e7ec6a1c1e470441860f6fc /remote_test.go
parent54afccfa0f5a5574525cbba3b4568cbda252a3df (diff)
Add `NewCredentialSSHKeyFromSigner` (#706)
This change adds `NewCredentialSSHKeyFromSigner`, which allows idiomatic use of SSH keys from Go. This also lets us spin off an SSH server in the tests.
Diffstat (limited to 'remote_test.go')
-rw-r--r--remote_test.go276
1 files changed, 276 insertions, 0 deletions
diff --git a/remote_test.go b/remote_test.go
index 4cc3298..b97d764 100644
--- a/remote_test.go
+++ b/remote_test.go
@@ -1,8 +1,21 @@
package git
import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
"fmt"
+ "io"
+ "net"
+ "os"
+ "os/exec"
+ "strings"
+ "sync"
"testing"
+ "time"
+
+ "github.com/google/shlex"
+ "golang.org/x/crypto/ssh"
)
func TestListRemotes(t *testing.T) {
@@ -184,3 +197,266 @@ func TestRemotePrune(t *testing.T) {
t.Fatal("Expected error getting a pruned reference")
}
}
+
+func newChannelPipe(t *testing.T, w io.Writer, wg *sync.WaitGroup) (*os.File, error) {
+ pr, pw, err := os.Pipe()
+ if err != nil {
+ return nil, err
+ }
+
+ wg.Add(1)
+ go func() {
+ _, err := io.Copy(w, pr)
+ if err != nil && err != io.EOF {
+ t.Logf("Failed to copy: %v", err)
+ }
+ wg.Done()
+ }()
+
+ return pw, nil
+}
+
+func startSSHServer(t *testing.T, hostKey ssh.Signer, authorizedKeys []ssh.PublicKey) net.Listener {
+ t.Helper()
+
+ marshaledAuthorizedKeys := make([][]byte, len(authorizedKeys))
+ for i, authorizedKey := range authorizedKeys {
+ marshaledAuthorizedKeys[i] = authorizedKey.Marshal()
+ }
+
+ config := &ssh.ServerConfig{
+ PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
+ marshaledPubKey := pubKey.Marshal()
+ for _, marshaledAuthorizedKey := range marshaledAuthorizedKeys {
+ if bytes.Equal(marshaledPubKey, marshaledAuthorizedKey) {
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "pubkey-fp": ssh.FingerprintSHA256(pubKey),
+ },
+ }, nil
+ }
+ }
+ t.Logf("unknown public key for %q:\n\t%+v\n\t%+v\n", c.User(), pubKey.Marshal(), authorizedKeys)
+ return nil, fmt.Errorf("unknown public key for %q", c.User())
+ },
+ }
+ config.AddHostKey(hostKey)
+
+ listener, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Failed to listen for connection: %v", err)
+ }
+
+ go func() {
+ nConn, err := listener.Accept()
+ if err != nil {
+ if strings.Contains(err.Error(), "use of closed network connection") {
+ return
+ }
+ t.Logf("Failed to accept incoming connection: %v", err)
+ return
+ }
+ defer nConn.Close()
+
+ conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
+ if err != nil {
+ t.Logf("failed to handshake: %+v, %+v", conn, err)
+ return
+ }
+
+ // The incoming Request channel must be serviced.
+ go func() {
+ for newRequest := range reqs {
+ t.Logf("new request %v", newRequest)
+ }
+ }()
+
+ // Service only the first channel request
+ newChannel := <-chans
+ defer func() {
+ for newChannel := range chans {
+ t.Logf("new channel %v", newChannel)
+ newChannel.Reject(ssh.UnknownChannelType, "server closing")
+ }
+ }()
+
+ // Channels have a type, depending on the application level
+ // protocol intended. In the case of a shell, the type is
+ // "session" and ServerShell may be used to present a simple
+ // terminal interface.
+ if newChannel.ChannelType() != "session" {
+ newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+ return
+ }
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ t.Logf("Could not accept channel: %v", err)
+ return
+ }
+ defer channel.Close()
+
+ // Sessions have out-of-band requests such as "shell",
+ // "pty-req" and "env". Here we handle only the
+ // "exec" request.
+ req := <-requests
+ if req.Type != "exec" {
+ req.Reply(false, nil)
+ return
+ }
+ // RFC 4254 Section 6.5.
+ var payload struct {
+ Command string
+ }
+ if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
+ t.Logf("invalid payload on channel %v: %v", channel, err)
+ req.Reply(false, nil)
+ return
+ }
+ args, err := shlex.Split(payload.Command)
+ if err != nil {
+ t.Logf("invalid command on channel %v: %v", channel, err)
+ req.Reply(false, nil)
+ return
+ }
+ if len(args) < 2 || (args[0] != "git-upload-pack" && args[0] != "git-receive-pack") {
+ t.Logf("invalid command (%v) on channel %v: %v", args, channel, err)
+ req.Reply(false, nil)
+ return
+ }
+ req.Reply(true, nil)
+
+ go func(in <-chan *ssh.Request) {
+ for req := range in {
+ t.Logf("draining request %v", req)
+ }
+ }(requests)
+
+ // The first parameter is the (absolute) path of the repository.
+ args[1] = "./testdata" + args[1]
+
+ cmd := exec.Command(args[0], args[1:]...)
+ cmd.Stdin = channel
+ var wg sync.WaitGroup
+ stdoutPipe, err := newChannelPipe(t, channel, &wg)
+ if err != nil {
+ t.Logf("Failed to create stdout pipe: %v", err)
+ return
+ }
+ cmd.Stdout = stdoutPipe
+ stderrPipe, err := newChannelPipe(t, channel.Stderr(), &wg)
+ if err != nil {
+ t.Logf("Failed to create stderr pipe: %v", err)
+ return
+ }
+ cmd.Stderr = stderrPipe
+
+ go func() {
+ wg.Wait()
+ channel.CloseWrite()
+ }()
+
+ err = cmd.Start()
+ if err != nil {
+ t.Logf("Failed to start %v: %v", args, err)
+ return
+ }
+
+ // Once the process has started, we need to close the write end of the
+ // pipes from this process so that we can know when the child has done
+ // writing to it.
+ stdoutPipe.Close()
+ stderrPipe.Close()
+
+ timer := time.AfterFunc(5*time.Second, func() {
+ t.Log("process timed out, terminating")
+ cmd.Process.Kill()
+ })
+ defer timer.Stop()
+
+ err = cmd.Wait()
+ if err != nil {
+ t.Logf("Failed to run %v: %v", args, err)
+ return
+ }
+ }()
+ return listener
+}
+
+func TestRemoteSSH(t *testing.T) {
+ t.Parallel()
+ pubKeyUsername := "testuser"
+
+ hostPrivKey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ t.Fatalf("Failed to generate the host RSA private key: %v", err)
+ }
+ hostSigner, err := ssh.NewSignerFromKey(hostPrivKey)
+ if err != nil {
+ t.Fatalf("Failed to generate SSH hostSigner: %v", err)
+ }
+
+ privKey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ t.Fatalf("Failed to generate the user RSA private key: %v", err)
+ }
+ signer, err := ssh.NewSignerFromKey(privKey)
+ if err != nil {
+ t.Fatalf("Failed to generate SSH signer: %v", err)
+ }
+ // This is in the format "xx:xx:xx:...", so we remove the colons so that it
+ // matches the fmt.Sprintf() below.
+ // Note that not all libssh2 implementations support the SHA256 fingerprint,
+ // so we use MD5 here for testing.
+ publicKeyFingerprint := strings.Replace(ssh.FingerprintLegacyMD5(hostSigner.PublicKey()), ":", "", -1)
+
+ listener := startSSHServer(t, hostSigner, []ssh.PublicKey{signer.PublicKey()})
+ defer listener.Close()
+
+ repo := createTestRepo(t)
+ defer cleanupTestRepo(t, repo)
+
+ certificateCheckCallbackCalled := false
+ fetchOpts := FetchOptions{
+ RemoteCallbacks: RemoteCallbacks{
+ CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode {
+ hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:])
+ if hostkeyFingerprint != publicKeyFingerprint {
+ t.Logf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint)
+ return ErrorCodeAuth
+ }
+ certificateCheckCallbackCalled = true
+ return ErrorCodeOK
+ },
+ CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) {
+ if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 {
+ return NewCredentialSSHKeyFromSigner(pubKeyUsername, signer)
+ }
+ if (allowedTypes & CredentialTypeUsername) != 0 {
+ return NewCredentialUsername(pubKeyUsername)
+ }
+ return nil, fmt.Errorf("unknown credential type %+v", allowedTypes)
+ },
+ },
+ }
+
+ remote, err := repo.Remotes.Create(
+ "origin",
+ fmt.Sprintf("ssh://%s/TestGitRepository", listener.Addr().String()),
+ )
+ checkFatal(t, err)
+ defer remote.Free()
+
+ err = remote.Fetch(nil, &fetchOpts, "")
+ checkFatal(t, err)
+ if !certificateCheckCallbackCalled {
+ t.Fatalf("CertificateCheckCallback was not called")
+ }
+
+ heads, err := remote.Ls()
+ checkFatal(t, err)
+
+ if len(heads) == 0 {
+ t.Error("Expected remote heads")
+ }
+}