diff options
| author | lhchavez <[email protected]> | 2020-12-06 11:55:04 -0800 |
|---|---|---|
| committer | GitHub <[email protected]> | 2020-12-06 11:55:04 -0800 |
| commit | abf02bc7d79dfb7b0bbcd404ebecb202cff2a18e (patch) | |
| tree | ec93caf1ed9238b91e7ec6a1c1e470441860f6fc /remote_test.go | |
| parent | 54afccfa0f5a5574525cbba3b4568cbda252a3df (diff) | |
Add `NewCredentialSSHKeyFromSigner` (#706)
This change adds `NewCredentialSSHKeyFromSigner`, which allows idiomatic
use of SSH keys from Go. This also lets us spin off an SSH server in the
tests.
Diffstat (limited to 'remote_test.go')
| -rw-r--r-- | remote_test.go | 276 |
1 files changed, 276 insertions, 0 deletions
diff --git a/remote_test.go b/remote_test.go index 4cc3298..b97d764 100644 --- a/remote_test.go +++ b/remote_test.go @@ -1,8 +1,21 @@ package git import ( + "bytes" + "crypto/rand" + "crypto/rsa" "fmt" + "io" + "net" + "os" + "os/exec" + "strings" + "sync" "testing" + "time" + + "github.com/google/shlex" + "golang.org/x/crypto/ssh" ) func TestListRemotes(t *testing.T) { @@ -184,3 +197,266 @@ func TestRemotePrune(t *testing.T) { t.Fatal("Expected error getting a pruned reference") } } + +func newChannelPipe(t *testing.T, w io.Writer, wg *sync.WaitGroup) (*os.File, error) { + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + + wg.Add(1) + go func() { + _, err := io.Copy(w, pr) + if err != nil && err != io.EOF { + t.Logf("Failed to copy: %v", err) + } + wg.Done() + }() + + return pw, nil +} + +func startSSHServer(t *testing.T, hostKey ssh.Signer, authorizedKeys []ssh.PublicKey) net.Listener { + t.Helper() + + marshaledAuthorizedKeys := make([][]byte, len(authorizedKeys)) + for i, authorizedKey := range authorizedKeys { + marshaledAuthorizedKeys[i] = authorizedKey.Marshal() + } + + config := &ssh.ServerConfig{ + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + marshaledPubKey := pubKey.Marshal() + for _, marshaledAuthorizedKey := range marshaledAuthorizedKeys { + if bytes.Equal(marshaledPubKey, marshaledAuthorizedKey) { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + } + t.Logf("unknown public key for %q:\n\t%+v\n\t%+v\n", c.User(), pubKey.Marshal(), authorizedKeys) + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + config.AddHostKey(hostKey) + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen for connection: %v", err) + } + + go func() { + nConn, err := listener.Accept() + if err != nil { + if strings.Contains(err.Error(), "use of closed network connection") { + return + } + t.Logf("Failed to accept incoming connection: %v", err) + return + } + defer nConn.Close() + + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + t.Logf("failed to handshake: %+v, %+v", conn, err) + return + } + + // The incoming Request channel must be serviced. + go func() { + for newRequest := range reqs { + t.Logf("new request %v", newRequest) + } + }() + + // Service only the first channel request + newChannel := <-chans + defer func() { + for newChannel := range chans { + t.Logf("new channel %v", newChannel) + newChannel.Reject(ssh.UnknownChannelType, "server closing") + } + }() + + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + channel, requests, err := newChannel.Accept() + if err != nil { + t.Logf("Could not accept channel: %v", err) + return + } + defer channel.Close() + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "exec" request. + req := <-requests + if req.Type != "exec" { + req.Reply(false, nil) + return + } + // RFC 4254 Section 6.5. + var payload struct { + Command string + } + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + t.Logf("invalid payload on channel %v: %v", channel, err) + req.Reply(false, nil) + return + } + args, err := shlex.Split(payload.Command) + if err != nil { + t.Logf("invalid command on channel %v: %v", channel, err) + req.Reply(false, nil) + return + } + if len(args) < 2 || (args[0] != "git-upload-pack" && args[0] != "git-receive-pack") { + t.Logf("invalid command (%v) on channel %v: %v", args, channel, err) + req.Reply(false, nil) + return + } + req.Reply(true, nil) + + go func(in <-chan *ssh.Request) { + for req := range in { + t.Logf("draining request %v", req) + } + }(requests) + + // The first parameter is the (absolute) path of the repository. + args[1] = "./testdata" + args[1] + + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdin = channel + var wg sync.WaitGroup + stdoutPipe, err := newChannelPipe(t, channel, &wg) + if err != nil { + t.Logf("Failed to create stdout pipe: %v", err) + return + } + cmd.Stdout = stdoutPipe + stderrPipe, err := newChannelPipe(t, channel.Stderr(), &wg) + if err != nil { + t.Logf("Failed to create stderr pipe: %v", err) + return + } + cmd.Stderr = stderrPipe + + go func() { + wg.Wait() + channel.CloseWrite() + }() + + err = cmd.Start() + if err != nil { + t.Logf("Failed to start %v: %v", args, err) + return + } + + // Once the process has started, we need to close the write end of the + // pipes from this process so that we can know when the child has done + // writing to it. + stdoutPipe.Close() + stderrPipe.Close() + + timer := time.AfterFunc(5*time.Second, func() { + t.Log("process timed out, terminating") + cmd.Process.Kill() + }) + defer timer.Stop() + + err = cmd.Wait() + if err != nil { + t.Logf("Failed to run %v: %v", args, err) + return + } + }() + return listener +} + +func TestRemoteSSH(t *testing.T) { + t.Parallel() + pubKeyUsername := "testuser" + + hostPrivKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("Failed to generate the host RSA private key: %v", err) + } + hostSigner, err := ssh.NewSignerFromKey(hostPrivKey) + if err != nil { + t.Fatalf("Failed to generate SSH hostSigner: %v", err) + } + + privKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("Failed to generate the user RSA private key: %v", err) + } + signer, err := ssh.NewSignerFromKey(privKey) + if err != nil { + t.Fatalf("Failed to generate SSH signer: %v", err) + } + // This is in the format "xx:xx:xx:...", so we remove the colons so that it + // matches the fmt.Sprintf() below. + // Note that not all libssh2 implementations support the SHA256 fingerprint, + // so we use MD5 here for testing. + publicKeyFingerprint := strings.Replace(ssh.FingerprintLegacyMD5(hostSigner.PublicKey()), ":", "", -1) + + listener := startSSHServer(t, hostSigner, []ssh.PublicKey{signer.PublicKey()}) + defer listener.Close() + + repo := createTestRepo(t) + defer cleanupTestRepo(t, repo) + + certificateCheckCallbackCalled := false + fetchOpts := FetchOptions{ + RemoteCallbacks: RemoteCallbacks{ + CertificateCheckCallback: func(cert *Certificate, valid bool, hostname string) ErrorCode { + hostkeyFingerprint := fmt.Sprintf("%x", cert.Hostkey.HashMD5[:]) + if hostkeyFingerprint != publicKeyFingerprint { + t.Logf("server hostkey %q, want %q", hostkeyFingerprint, publicKeyFingerprint) + return ErrorCodeAuth + } + certificateCheckCallbackCalled = true + return ErrorCodeOK + }, + CredentialsCallback: func(url, username string, allowedTypes CredentialType) (*Credential, error) { + if allowedTypes&(CredentialTypeSSHKey|CredentialTypeSSHCustom|CredentialTypeSSHMemory) != 0 { + return NewCredentialSSHKeyFromSigner(pubKeyUsername, signer) + } + if (allowedTypes & CredentialTypeUsername) != 0 { + return NewCredentialUsername(pubKeyUsername) + } + return nil, fmt.Errorf("unknown credential type %+v", allowedTypes) + }, + }, + } + + remote, err := repo.Remotes.Create( + "origin", + fmt.Sprintf("ssh://%s/TestGitRepository", listener.Addr().String()), + ) + checkFatal(t, err) + defer remote.Free() + + err = remote.Fetch(nil, &fetchOpts, "") + checkFatal(t, err) + if !certificateCheckCallbackCalled { + t.Fatalf("CertificateCheckCallback was not called") + } + + heads, err := remote.Ls() + checkFatal(t, err) + + if len(heads) == 0 { + t.Error("Expected remote heads") + } +} |
