summaryrefslogtreecommitdiff
path: root/remote.go
diff options
context:
space:
mode:
authorlhchavez <[email protected]>2020-12-10 07:19:41 -0800
committerGitHub <[email protected]>2020-12-10 07:19:41 -0800
commit10c67474a89c298172a6703b91980ea37c60d5e5 (patch)
tree8b32fd2ce9540e01e90ddd09b59969d85832dc25 /remote.go
parente28cce87c7551bffa1f4602ff492348f9a8cba60 (diff)
More callback refactoring (#713)
This change: * Gets rid of the `.toC()` functions for Options objects, since they were redundant with the `populateXxxOptions()`. * Adds support for `errorTarget` to the `RemoteOptions`, since they are used in the same stack for some functions (like `Fetch()`). Now for those cases, the error returned by the callback will be preserved as-is.
Diffstat (limited to 'remote.go')
-rw-r--r--remote.go309
1 files changed, 198 insertions, 111 deletions
diff --git a/remote.go b/remote.go
index b1c8532..1887f79 100644
--- a/remote.go
+++ b/remote.go
@@ -7,7 +7,6 @@ package git
#include <git2/sys/cred.h>
extern void _go_git_populate_remote_callbacks(git_remote_callbacks *callbacks);
-
*/
import "C"
import (
@@ -72,6 +71,11 @@ type RemoteCallbacks struct {
PushUpdateReferenceCallback
}
+type remoteCallbacksData struct {
+ callbacks *RemoteCallbacks
+ errorTarget *error
+}
+
type FetchPrune uint
const (
@@ -86,7 +90,6 @@ const (
type DownloadTags uint
const (
-
// Use the setting from the configuration.
DownloadTagsUnspecified DownloadTags = C.GIT_REMOTE_DOWNLOAD_TAGS_UNSPECIFIED
// Ask the server for tags pointing to objects we're already
@@ -209,44 +212,58 @@ func newRemoteHeadFromC(ptr *C.git_remote_head) RemoteHead {
}
}
-func untrackCalbacksPayload(callbacks *C.git_remote_callbacks) {
- if callbacks != nil && callbacks.payload != nil {
- pointerHandles.Untrack(callbacks.payload)
+func untrackCallbacksPayload(callbacks *C.git_remote_callbacks) {
+ if callbacks == nil || callbacks.payload == nil {
+ return
}
+ pointerHandles.Untrack(callbacks.payload)
}
-func populateRemoteCallbacks(ptr *C.git_remote_callbacks, callbacks *RemoteCallbacks) {
+func populateRemoteCallbacks(ptr *C.git_remote_callbacks, callbacks *RemoteCallbacks, errorTarget *error) *C.git_remote_callbacks {
C.git_remote_init_callbacks(ptr, C.GIT_REMOTE_CALLBACKS_VERSION)
if callbacks == nil {
- return
+ return ptr
}
C._go_git_populate_remote_callbacks(ptr)
- ptr.payload = pointerHandles.Track(callbacks)
+ data := &remoteCallbacksData{
+ callbacks: callbacks,
+ errorTarget: errorTarget,
+ }
+ ptr.payload = pointerHandles.Track(data)
+ return ptr
}
//export sidebandProgressCallback
-func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, data unsafe.Pointer) C.int {
- callbacks := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.SidebandProgressCallback == nil {
+func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.SidebandProgressCallback == nil {
return C.int(ErrorCodeOK)
}
str := C.GoStringN(_str, _len)
- ret := callbacks.SidebandProgressCallback(str)
+ ret := data.callbacks.SidebandProgressCallback(str)
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
//export completionCallback
-func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, data unsafe.Pointer) C.int {
- callbacks := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.CompletionCallback == nil {
+func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.CompletionCallback == nil {
return C.int(ErrorCodeOK)
}
- ret := callbacks.CompletionCallback(RemoteCompletion(completion_type))
+ ret := data.callbacks.CompletionCallback(RemoteCompletion(completion_type))
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
@@ -258,16 +275,19 @@ func credentialsCallback(
_url *C.char,
_username_from_url *C.char,
allowed_types uint,
- data unsafe.Pointer,
+ handle unsafe.Pointer,
) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.CredentialsCallback == nil {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.CredentialsCallback == nil {
return C.int(ErrorCodePassthrough)
}
url := C.GoString(_url)
username_from_url := C.GoString(_username_from_url)
- cred, err := callbacks.CredentialsCallback(url, username_from_url, (CredentialType)(allowed_types))
+ cred, err := data.callbacks.CredentialsCallback(url, username_from_url, (CredentialType)(allowed_types))
if err != nil {
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
return setCallbackError(errorMessage, err)
}
if cred != nil {
@@ -281,14 +301,18 @@ func credentialsCallback(
}
//export transferProgressCallback
-func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progress, data unsafe.Pointer) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.TransferProgressCallback == nil {
+func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progress, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.TransferProgressCallback == nil {
return C.int(ErrorCodeOK)
}
- ret := callbacks.TransferProgressCallback(newTransferProgressFromC(stats))
+ ret := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats))
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
@@ -299,18 +323,22 @@ func updateTipsCallback(
_refname *C.char,
_a *C.git_oid,
_b *C.git_oid,
- data unsafe.Pointer,
+ handle unsafe.Pointer,
) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.UpdateTipsCallback == nil {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.UpdateTipsCallback == nil {
return C.int(ErrorCodeOK)
}
refname := C.GoString(_refname)
a := newOidFromC(_a)
b := newOidFromC(_b)
- ret := callbacks.UpdateTipsCallback(refname, a, b)
+ ret := data.callbacks.UpdateTipsCallback(refname, a, b)
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
@@ -321,11 +349,11 @@ func certificateCheckCallback(
_cert *C.git_cert,
_valid C.int,
_host *C.char,
- data unsafe.Pointer,
+ handle unsafe.Pointer,
) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
// if there's no callback set, we need to make sure we fail if the library didn't consider this cert valid
- if callbacks.CertificateCheckCallback == nil {
+ if data.callbacks.CertificateCheckCallback == nil {
if _valid == 0 {
return C.int(ErrorCodeCertificate)
}
@@ -341,10 +369,17 @@ func certificateCheckCallback(
ccert := (*C.git_cert_x509)(unsafe.Pointer(_cert))
x509_certs, err := x509.ParseCertificates(C.GoBytes(ccert.data, C.int(ccert.len)))
if err != nil {
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
return setCallbackError(errorMessage, err)
}
if len(x509_certs) < 1 {
- return setCallbackError(errorMessage, errors.New("empty certificate list"))
+ err := errors.New("empty certificate list")
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
// we assume there's only one, which should hold true for any web server we want to talk to
@@ -357,74 +392,95 @@ func certificateCheckCallback(
C.memcpy(unsafe.Pointer(&cert.Hostkey.HashSHA1[0]), unsafe.Pointer(&ccert.hash_sha1[0]), C.size_t(len(cert.Hostkey.HashSHA1)))
C.memcpy(unsafe.Pointer(&cert.Hostkey.HashSHA256[0]), unsafe.Pointer(&ccert.hash_sha256[0]), C.size_t(len(cert.Hostkey.HashSHA256)))
} else {
- return setCallbackError(errorMessage, errors.New("unsupported certificate type"))
+ err := errors.New("unsupported certificate type")
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
- ret := callbacks.CertificateCheckCallback(&cert, valid, host)
+ ret := data.callbacks.CertificateCheckCallback(&cert, valid, host)
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
//export packProgressCallback
-func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.uint, data unsafe.Pointer) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.PackProgressCallback == nil {
+func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.uint, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.PackProgressCallback == nil {
return C.int(ErrorCodeOK)
}
- ret := callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total))
+ ret := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total))
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
//export pushTransferProgressCallback
-func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint, bytes C.size_t, data unsafe.Pointer) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.PushTransferProgressCallback == nil {
+func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint, bytes C.size_t, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.PushTransferProgressCallback == nil {
return C.int(ErrorCodeOK)
}
- ret := callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes))
+ ret := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes))
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
//export pushUpdateReferenceCallback
-func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char, data unsafe.Pointer) C.int {
- callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks)
- if callbacks.PushUpdateReferenceCallback == nil {
+func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char, handle unsafe.Pointer) C.int {
+ data := pointerHandles.Get(handle).(*remoteCallbacksData)
+ if data.callbacks.PushUpdateReferenceCallback == nil {
return C.int(ErrorCodeOK)
}
- ret := callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status))
+ ret := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status))
if ret < 0 {
- return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String()))
+ err := errors.New(ErrorCode(ret).String())
+ if data.errorTarget != nil {
+ *data.errorTarget = err
+ }
+ return setCallbackError(errorMessage, err)
}
return C.int(ErrorCodeOK)
}
-func populateProxyOptions(ptr *C.git_proxy_options, opts *ProxyOptions) {
- C.git_proxy_options_init(ptr, C.GIT_PROXY_OPTIONS_VERSION)
+func populateProxyOptions(copts *C.git_proxy_options, opts *ProxyOptions) *C.git_proxy_options {
+ C.git_proxy_options_init(copts, C.GIT_PROXY_OPTIONS_VERSION)
if opts == nil {
- return
+ return nil
}
- ptr._type = C.git_proxy_t(opts.Type)
- ptr.url = C.CString(opts.Url)
+ copts._type = C.git_proxy_t(opts.Type)
+ copts.url = C.CString(opts.Url)
+ return copts
}
-func freeProxyOptions(ptr *C.git_proxy_options) {
- if ptr == nil {
+func freeProxyOptions(copts *C.git_proxy_options) {
+ if copts == nil {
return
}
- C.free(unsafe.Pointer(ptr.url))
+ C.free(unsafe.Pointer(copts.url))
}
// RemoteIsValidName returns whether the remote name is well-formed.
@@ -738,35 +794,54 @@ func (o *Remote) RefspecCount() uint {
return uint(count)
}
-func populateFetchOptions(options *C.git_fetch_options, opts *FetchOptions) {
- C.git_fetch_options_init(options, C.GIT_FETCH_OPTIONS_VERSION)
+func populateFetchOptions(copts *C.git_fetch_options, opts *FetchOptions, errorTarget *error) *C.git_fetch_options {
+ C.git_fetch_options_init(copts, C.GIT_FETCH_OPTIONS_VERSION)
if opts == nil {
- return
+ return nil
}
- populateRemoteCallbacks(&options.callbacks, &opts.RemoteCallbacks)
- options.prune = C.git_fetch_prune_t(opts.Prune)
- options.update_fetchhead = cbool(opts.UpdateFetchhead)
- options.download_tags = C.git_remote_autotag_option_t(opts.DownloadTags)
+ populateRemoteCallbacks(&copts.callbacks, &opts.RemoteCallbacks, errorTarget)
+ copts.prune = C.git_fetch_prune_t(opts.Prune)
+ copts.update_fetchhead = cbool(opts.UpdateFetchhead)
+ copts.download_tags = C.git_remote_autotag_option_t(opts.DownloadTags)
- options.custom_headers = C.git_strarray{}
- options.custom_headers.count = C.size_t(len(opts.Headers))
- options.custom_headers.strings = makeCStringsFromStrings(opts.Headers)
- populateProxyOptions(&options.proxy_opts, &opts.ProxyOptions)
+ copts.custom_headers = C.git_strarray{
+ count: C.size_t(len(opts.Headers)),
+ strings: makeCStringsFromStrings(opts.Headers),
+ }
+ populateProxyOptions(&copts.proxy_opts, &opts.ProxyOptions)
+ return copts
}
-func populatePushOptions(options *C.git_push_options, opts *PushOptions) {
- C.git_push_options_init(options, C.GIT_PUSH_OPTIONS_VERSION)
- if opts == nil {
+func freeFetchOptions(copts *C.git_fetch_options) {
+ if copts == nil {
return
}
+ freeStrarray(&copts.custom_headers)
+ untrackCallbacksPayload(&copts.callbacks)
+ freeProxyOptions(&copts.proxy_opts)
+}
- options.pb_parallelism = C.uint(opts.PbParallelism)
+func populatePushOptions(copts *C.git_push_options, opts *PushOptions, errorTarget *error) *C.git_push_options {
+ C.git_push_options_init(copts, C.GIT_PUSH_OPTIONS_VERSION)
+ if opts == nil {
+ return nil
+ }
- options.custom_headers = C.git_strarray{}
- options.custom_headers.count = C.size_t(len(opts.Headers))
- options.custom_headers.strings = makeCStringsFromStrings(opts.Headers)
+ copts.pb_parallelism = C.uint(opts.PbParallelism)
+ copts.custom_headers = C.git_strarray{
+ count: C.size_t(len(opts.Headers)),
+ strings: makeCStringsFromStrings(opts.Headers),
+ }
+ populateRemoteCallbacks(&copts.callbacks, &opts.RemoteCallbacks, errorTarget)
+ return copts
+}
- populateRemoteCallbacks(&options.callbacks, &opts.RemoteCallbacks)
+func freePushOptions(copts *C.git_push_options) {
+ if copts == nil {
+ return
+ }
+ untrackCallbacksPayload(&copts.callbacks)
+ freeStrarray(&copts.custom_headers)
}
// Fetch performs a fetch operation. refspecs specifies which refspecs
@@ -780,26 +855,29 @@ func (o *Remote) Fetch(refspecs []string, opts *FetchOptions, msg string) error
defer C.free(unsafe.Pointer(cmsg))
}
- crefspecs := C.git_strarray{}
- crefspecs.count = C.size_t(len(refspecs))
- crefspecs.strings = makeCStringsFromStrings(refspecs)
+ var err error
+ crefspecs := C.git_strarray{
+ count: C.size_t(len(refspecs)),
+ strings: makeCStringsFromStrings(refspecs),
+ }
defer freeStrarray(&crefspecs)
- coptions := (*C.git_fetch_options)(C.calloc(1, C.size_t(unsafe.Sizeof(C.git_fetch_options{}))))
- defer C.free(unsafe.Pointer(coptions))
-
- populateFetchOptions(coptions, opts)
- defer untrackCalbacksPayload(&coptions.callbacks)
- defer freeStrarray(&coptions.custom_headers)
+ coptions := populateFetchOptions(&C.git_fetch_options{}, opts, &err)
+ defer freeFetchOptions(coptions)
runtime.LockOSThread()
defer runtime.UnlockOSThread()
ret := C.git_remote_fetch(o.ptr, &crefspecs, coptions, cmsg)
runtime.KeepAlive(o)
+
+ if ret == C.int(ErrorCodeUser) && err != nil {
+ return err
+ }
if ret < 0 {
return MakeGitError(ret)
}
+
return nil
}
@@ -819,23 +897,27 @@ func (o *Remote) ConnectPush(callbacks *RemoteCallbacks, proxyOpts *ProxyOptions
//
// 'headers' are extra HTTP headers to use in this connection.
func (o *Remote) Connect(direction ConnectDirection, callbacks *RemoteCallbacks, proxyOpts *ProxyOptions, headers []string) error {
- var ccallbacks C.git_remote_callbacks
- populateRemoteCallbacks(&ccallbacks, callbacks)
+ var err error
+ ccallbacks := populateRemoteCallbacks(&C.git_remote_callbacks{}, callbacks, &err)
+ defer untrackCallbacksPayload(ccallbacks)
- var cproxy C.git_proxy_options
- populateProxyOptions(&cproxy, proxyOpts)
- defer freeProxyOptions(&cproxy)
+ cproxy := populateProxyOptions(&C.git_proxy_options{}, proxyOpts)
+ defer freeProxyOptions(cproxy)
- cheaders := C.git_strarray{}
- cheaders.count = C.size_t(len(headers))
- cheaders.strings = makeCStringsFromStrings(headers)
+ cheaders := C.git_strarray{
+ count: C.size_t(len(headers)),
+ strings: makeCStringsFromStrings(headers),
+ }
defer freeStrarray(&cheaders)
runtime.LockOSThread()
defer runtime.UnlockOSThread()
- ret := C.git_remote_connect(o.ptr, C.git_direction(direction), &ccallbacks, &cproxy, &cheaders)
+ ret := C.git_remote_connect(o.ptr, C.git_direction(direction), ccallbacks, cproxy, &cheaders)
runtime.KeepAlive(o)
+ if ret == C.int(ErrorCodeUser) && err != nil {
+ return err
+ }
if ret != 0 {
return MakeGitError(ret)
}
@@ -899,23 +981,24 @@ func (o *Remote) Ls(filterRefs ...string) ([]RemoteHead, error) {
}
func (o *Remote) Push(refspecs []string, opts *PushOptions) error {
- crefspecs := C.git_strarray{}
- crefspecs.count = C.size_t(len(refspecs))
- crefspecs.strings = makeCStringsFromStrings(refspecs)
+ crefspecs := C.git_strarray{
+ count: C.size_t(len(refspecs)),
+ strings: makeCStringsFromStrings(refspecs),
+ }
defer freeStrarray(&crefspecs)
- coptions := (*C.git_push_options)(C.calloc(1, C.size_t(unsafe.Sizeof(C.git_push_options{}))))
- defer C.free(unsafe.Pointer(coptions))
-
- populatePushOptions(coptions, opts)
- defer untrackCalbacksPayload(&coptions.callbacks)
- defer freeStrarray(&coptions.custom_headers)
+ var err error
+ coptions := populatePushOptions(&C.git_push_options{}, opts, &err)
+ defer freePushOptions(coptions)
runtime.LockOSThread()
defer runtime.UnlockOSThread()
ret := C.git_remote_push(o.ptr, &crefspecs, coptions)
runtime.KeepAlive(o)
+ if ret == C.int(ErrorCodeUser) && err != nil {
+ return err
+ }
if ret < 0 {
return MakeGitError(ret)
}
@@ -927,14 +1010,18 @@ func (o *Remote) PruneRefs() bool {
}
func (o *Remote) Prune(callbacks *RemoteCallbacks) error {
- var ccallbacks C.git_remote_callbacks
- populateRemoteCallbacks(&ccallbacks, callbacks)
+ var err error
+ ccallbacks := populateRemoteCallbacks(&C.git_remote_callbacks{}, callbacks, &err)
+ defer untrackCallbacksPayload(ccallbacks)
runtime.LockOSThread()
defer runtime.UnlockOSThread()
- ret := C.git_remote_prune(o.ptr, &ccallbacks)
+ ret := C.git_remote_prune(o.ptr, ccallbacks)
runtime.KeepAlive(o)
+ if ret == C.int(ErrorCodeUser) && err != nil {
+ return err
+ }
if ret < 0 {
return MakeGitError(ret)
}