summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tree.go18
-rw-r--r--tree_test.go74
2 files changed, 90 insertions, 2 deletions
diff --git a/tree.go b/tree.go
index b1aeaa7..1410b96 100644
--- a/tree.go
+++ b/tree.go
@@ -8,6 +8,7 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C"
import (
+ "errors"
"runtime"
"unsafe"
)
@@ -134,7 +135,9 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
}
err := data.callback(C.GoString(_root), newTreeEntry(entry))
- if err != nil {
+ if err == TreeWalkSkip {
+ return C.int(1)
+ } else if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}
@@ -142,6 +145,19 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
return C.int(ErrorCodeOK)
}
+// TreeWalkSkip is an error that can be returned form TreeWalkCallback to skip
+// a subtree from being expanded.
+var TreeWalkSkip = errors.New("skip")
+
+// Walk traverses the entries in a tree and its subtrees in pre order.
+//
+// The entries will be traversed in the pre order, children subtrees will be
+// automatically loaded as required, and the callback will be called once per
+// entry with the current (relative) root for the entry and the entry data
+// itself.
+//
+// If the callback returns TreeWalkSkip, the passed entry will be skipped on
+// the traversal. Any other non-nil error stops the walk.
func (t *Tree) Walk(callback TreeWalkCallback) error {
var err error
data := treeWalkCallbackData{
diff --git a/tree_test.go b/tree_test.go
index f5b6822..0c0c004 100644
--- a/tree_test.go
+++ b/tree_test.go
@@ -1,6 +1,9 @@
package git
-import "testing"
+import (
+ "errors"
+ "testing"
+)
func TestTreeEntryById(t *testing.T) {
t.Parallel()
@@ -63,3 +66,72 @@ func TestTreeBuilderInsert(t *testing.T) {
t.Fatalf("got oid %v, want %v", entry.Id, blobId)
}
}
+
+func TestTreeWalk(t *testing.T) {
+ t.Parallel()
+ repo, err := OpenRepository("testdata/TestGitRepository.git")
+ checkFatal(t, err)
+ treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
+ checkFatal(t, err)
+
+ tree, err := repo.LookupTree(treeID)
+ checkFatal(t, err)
+
+ var callCount int
+ err = tree.Walk(func(name string, entry *TreeEntry) error {
+ callCount++
+
+ return nil
+ })
+ checkFatal(t, err)
+ if callCount != 11 {
+ t.Fatalf("got called %v times, want %v", callCount, 11)
+ }
+}
+
+func TestTreeWalkSkip(t *testing.T) {
+ t.Parallel()
+ repo, err := OpenRepository("testdata/TestGitRepository.git")
+ checkFatal(t, err)
+ treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
+ checkFatal(t, err)
+
+ tree, err := repo.LookupTree(treeID)
+ checkFatal(t, err)
+
+ var callCount int
+ err = tree.Walk(func(name string, entry *TreeEntry) error {
+ callCount++
+
+ return TreeWalkSkip
+ })
+ checkFatal(t, err)
+ if callCount != 4 {
+ t.Fatalf("got called %v times, want %v", callCount, 4)
+ }
+}
+
+func TestTreeWalkStop(t *testing.T) {
+ t.Parallel()
+ repo, err := OpenRepository("testdata/TestGitRepository.git")
+ checkFatal(t, err)
+ treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
+ checkFatal(t, err)
+
+ tree, err := repo.LookupTree(treeID)
+ checkFatal(t, err)
+
+ var callCount int
+ stopError := errors.New("stop")
+ err = tree.Walk(func(name string, entry *TreeEntry) error {
+ callCount++
+
+ return stopError
+ })
+ if err != stopError {
+ t.Fatalf("got error %v, want %v", err, stopError)
+ }
+ if callCount != 1 {
+ t.Fatalf("got called %v times, want %v", callCount, 1)
+ }
+}