diff options
| author | lhchavez <[email protected]> | 2020-12-10 07:19:41 -0800 |
|---|---|---|
| committer | GitHub <[email protected]> | 2020-12-10 07:19:41 -0800 |
| commit | 10c67474a89c298172a6703b91980ea37c60d5e5 (patch) | |
| tree | 8b32fd2ce9540e01e90ddd09b59969d85832dc25 /remote.go | |
| parent | e28cce87c7551bffa1f4602ff492348f9a8cba60 (diff) | |
More callback refactoring (#713)
This change:
* Gets rid of the `.toC()` functions for Options objects, since they
were redundant with the `populateXxxOptions()`.
* Adds support for `errorTarget` to the `RemoteOptions`, since they are
used in the same stack for some functions (like `Fetch()`). Now for
those cases, the error returned by the callback will be preserved
as-is.
Diffstat (limited to 'remote.go')
| -rw-r--r-- | remote.go | 309 |
1 files changed, 198 insertions, 111 deletions
@@ -7,7 +7,6 @@ package git #include <git2/sys/cred.h> extern void _go_git_populate_remote_callbacks(git_remote_callbacks *callbacks); - */ import "C" import ( @@ -72,6 +71,11 @@ type RemoteCallbacks struct { PushUpdateReferenceCallback } +type remoteCallbacksData struct { + callbacks *RemoteCallbacks + errorTarget *error +} + type FetchPrune uint const ( @@ -86,7 +90,6 @@ const ( type DownloadTags uint const ( - // Use the setting from the configuration. DownloadTagsUnspecified DownloadTags = C.GIT_REMOTE_DOWNLOAD_TAGS_UNSPECIFIED // Ask the server for tags pointing to objects we're already @@ -209,44 +212,58 @@ func newRemoteHeadFromC(ptr *C.git_remote_head) RemoteHead { } } -func untrackCalbacksPayload(callbacks *C.git_remote_callbacks) { - if callbacks != nil && callbacks.payload != nil { - pointerHandles.Untrack(callbacks.payload) +func untrackCallbacksPayload(callbacks *C.git_remote_callbacks) { + if callbacks == nil || callbacks.payload == nil { + return } + pointerHandles.Untrack(callbacks.payload) } -func populateRemoteCallbacks(ptr *C.git_remote_callbacks, callbacks *RemoteCallbacks) { +func populateRemoteCallbacks(ptr *C.git_remote_callbacks, callbacks *RemoteCallbacks, errorTarget *error) *C.git_remote_callbacks { C.git_remote_init_callbacks(ptr, C.GIT_REMOTE_CALLBACKS_VERSION) if callbacks == nil { - return + return ptr } C._go_git_populate_remote_callbacks(ptr) - ptr.payload = pointerHandles.Track(callbacks) + data := &remoteCallbacksData{ + callbacks: callbacks, + errorTarget: errorTarget, + } + ptr.payload = pointerHandles.Track(data) + return ptr } //export sidebandProgressCallback -func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, data unsafe.Pointer) C.int { - callbacks := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.SidebandProgressCallback == nil { +func sidebandProgressCallback(errorMessage **C.char, _str *C.char, _len C.int, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.SidebandProgressCallback == nil { return C.int(ErrorCodeOK) } str := C.GoStringN(_str, _len) - ret := callbacks.SidebandProgressCallback(str) + ret := data.callbacks.SidebandProgressCallback(str) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } //export completionCallback -func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, data unsafe.Pointer) C.int { - callbacks := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.CompletionCallback == nil { +func completionCallback(errorMessage **C.char, completion_type C.git_remote_completion_type, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.CompletionCallback == nil { return C.int(ErrorCodeOK) } - ret := callbacks.CompletionCallback(RemoteCompletion(completion_type)) + ret := data.callbacks.CompletionCallback(RemoteCompletion(completion_type)) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } @@ -258,16 +275,19 @@ func credentialsCallback( _url *C.char, _username_from_url *C.char, allowed_types uint, - data unsafe.Pointer, + handle unsafe.Pointer, ) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.CredentialsCallback == nil { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.CredentialsCallback == nil { return C.int(ErrorCodePassthrough) } url := C.GoString(_url) username_from_url := C.GoString(_username_from_url) - cred, err := callbacks.CredentialsCallback(url, username_from_url, (CredentialType)(allowed_types)) + cred, err := data.callbacks.CredentialsCallback(url, username_from_url, (CredentialType)(allowed_types)) if err != nil { + if data.errorTarget != nil { + *data.errorTarget = err + } return setCallbackError(errorMessage, err) } if cred != nil { @@ -281,14 +301,18 @@ func credentialsCallback( } //export transferProgressCallback -func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progress, data unsafe.Pointer) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.TransferProgressCallback == nil { +func transferProgressCallback(errorMessage **C.char, stats *C.git_transfer_progress, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.TransferProgressCallback == nil { return C.int(ErrorCodeOK) } - ret := callbacks.TransferProgressCallback(newTransferProgressFromC(stats)) + ret := data.callbacks.TransferProgressCallback(newTransferProgressFromC(stats)) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } @@ -299,18 +323,22 @@ func updateTipsCallback( _refname *C.char, _a *C.git_oid, _b *C.git_oid, - data unsafe.Pointer, + handle unsafe.Pointer, ) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.UpdateTipsCallback == nil { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.UpdateTipsCallback == nil { return C.int(ErrorCodeOK) } refname := C.GoString(_refname) a := newOidFromC(_a) b := newOidFromC(_b) - ret := callbacks.UpdateTipsCallback(refname, a, b) + ret := data.callbacks.UpdateTipsCallback(refname, a, b) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } @@ -321,11 +349,11 @@ func certificateCheckCallback( _cert *C.git_cert, _valid C.int, _host *C.char, - data unsafe.Pointer, + handle unsafe.Pointer, ) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) + data := pointerHandles.Get(handle).(*remoteCallbacksData) // if there's no callback set, we need to make sure we fail if the library didn't consider this cert valid - if callbacks.CertificateCheckCallback == nil { + if data.callbacks.CertificateCheckCallback == nil { if _valid == 0 { return C.int(ErrorCodeCertificate) } @@ -341,10 +369,17 @@ func certificateCheckCallback( ccert := (*C.git_cert_x509)(unsafe.Pointer(_cert)) x509_certs, err := x509.ParseCertificates(C.GoBytes(ccert.data, C.int(ccert.len))) if err != nil { + if data.errorTarget != nil { + *data.errorTarget = err + } return setCallbackError(errorMessage, err) } if len(x509_certs) < 1 { - return setCallbackError(errorMessage, errors.New("empty certificate list")) + err := errors.New("empty certificate list") + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } // we assume there's only one, which should hold true for any web server we want to talk to @@ -357,74 +392,95 @@ func certificateCheckCallback( C.memcpy(unsafe.Pointer(&cert.Hostkey.HashSHA1[0]), unsafe.Pointer(&ccert.hash_sha1[0]), C.size_t(len(cert.Hostkey.HashSHA1))) C.memcpy(unsafe.Pointer(&cert.Hostkey.HashSHA256[0]), unsafe.Pointer(&ccert.hash_sha256[0]), C.size_t(len(cert.Hostkey.HashSHA256))) } else { - return setCallbackError(errorMessage, errors.New("unsupported certificate type")) + err := errors.New("unsupported certificate type") + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } - ret := callbacks.CertificateCheckCallback(&cert, valid, host) + ret := data.callbacks.CertificateCheckCallback(&cert, valid, host) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } //export packProgressCallback -func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.uint, data unsafe.Pointer) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.PackProgressCallback == nil { +func packProgressCallback(errorMessage **C.char, stage C.int, current, total C.uint, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.PackProgressCallback == nil { return C.int(ErrorCodeOK) } - ret := callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total)) + ret := data.callbacks.PackProgressCallback(int32(stage), uint32(current), uint32(total)) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } //export pushTransferProgressCallback -func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint, bytes C.size_t, data unsafe.Pointer) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.PushTransferProgressCallback == nil { +func pushTransferProgressCallback(errorMessage **C.char, current, total C.uint, bytes C.size_t, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.PushTransferProgressCallback == nil { return C.int(ErrorCodeOK) } - ret := callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes)) + ret := data.callbacks.PushTransferProgressCallback(uint32(current), uint32(total), uint(bytes)) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } //export pushUpdateReferenceCallback -func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char, data unsafe.Pointer) C.int { - callbacks, _ := pointerHandles.Get(data).(*RemoteCallbacks) - if callbacks.PushUpdateReferenceCallback == nil { +func pushUpdateReferenceCallback(errorMessage **C.char, refname, status *C.char, handle unsafe.Pointer) C.int { + data := pointerHandles.Get(handle).(*remoteCallbacksData) + if data.callbacks.PushUpdateReferenceCallback == nil { return C.int(ErrorCodeOK) } - ret := callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status)) + ret := data.callbacks.PushUpdateReferenceCallback(C.GoString(refname), C.GoString(status)) if ret < 0 { - return setCallbackError(errorMessage, errors.New(ErrorCode(ret).String())) + err := errors.New(ErrorCode(ret).String()) + if data.errorTarget != nil { + *data.errorTarget = err + } + return setCallbackError(errorMessage, err) } return C.int(ErrorCodeOK) } -func populateProxyOptions(ptr *C.git_proxy_options, opts *ProxyOptions) { - C.git_proxy_options_init(ptr, C.GIT_PROXY_OPTIONS_VERSION) +func populateProxyOptions(copts *C.git_proxy_options, opts *ProxyOptions) *C.git_proxy_options { + C.git_proxy_options_init(copts, C.GIT_PROXY_OPTIONS_VERSION) if opts == nil { - return + return nil } - ptr._type = C.git_proxy_t(opts.Type) - ptr.url = C.CString(opts.Url) + copts._type = C.git_proxy_t(opts.Type) + copts.url = C.CString(opts.Url) + return copts } -func freeProxyOptions(ptr *C.git_proxy_options) { - if ptr == nil { +func freeProxyOptions(copts *C.git_proxy_options) { + if copts == nil { return } - C.free(unsafe.Pointer(ptr.url)) + C.free(unsafe.Pointer(copts.url)) } // RemoteIsValidName returns whether the remote name is well-formed. @@ -738,35 +794,54 @@ func (o *Remote) RefspecCount() uint { return uint(count) } -func populateFetchOptions(options *C.git_fetch_options, opts *FetchOptions) { - C.git_fetch_options_init(options, C.GIT_FETCH_OPTIONS_VERSION) +func populateFetchOptions(copts *C.git_fetch_options, opts *FetchOptions, errorTarget *error) *C.git_fetch_options { + C.git_fetch_options_init(copts, C.GIT_FETCH_OPTIONS_VERSION) if opts == nil { - return + return nil } - populateRemoteCallbacks(&options.callbacks, &opts.RemoteCallbacks) - options.prune = C.git_fetch_prune_t(opts.Prune) - options.update_fetchhead = cbool(opts.UpdateFetchhead) - options.download_tags = C.git_remote_autotag_option_t(opts.DownloadTags) + populateRemoteCallbacks(&copts.callbacks, &opts.RemoteCallbacks, errorTarget) + copts.prune = C.git_fetch_prune_t(opts.Prune) + copts.update_fetchhead = cbool(opts.UpdateFetchhead) + copts.download_tags = C.git_remote_autotag_option_t(opts.DownloadTags) - options.custom_headers = C.git_strarray{} - options.custom_headers.count = C.size_t(len(opts.Headers)) - options.custom_headers.strings = makeCStringsFromStrings(opts.Headers) - populateProxyOptions(&options.proxy_opts, &opts.ProxyOptions) + copts.custom_headers = C.git_strarray{ + count: C.size_t(len(opts.Headers)), + strings: makeCStringsFromStrings(opts.Headers), + } + populateProxyOptions(&copts.proxy_opts, &opts.ProxyOptions) + return copts } -func populatePushOptions(options *C.git_push_options, opts *PushOptions) { - C.git_push_options_init(options, C.GIT_PUSH_OPTIONS_VERSION) - if opts == nil { +func freeFetchOptions(copts *C.git_fetch_options) { + if copts == nil { return } + freeStrarray(&copts.custom_headers) + untrackCallbacksPayload(&copts.callbacks) + freeProxyOptions(&copts.proxy_opts) +} - options.pb_parallelism = C.uint(opts.PbParallelism) +func populatePushOptions(copts *C.git_push_options, opts *PushOptions, errorTarget *error) *C.git_push_options { + C.git_push_options_init(copts, C.GIT_PUSH_OPTIONS_VERSION) + if opts == nil { + return nil + } - options.custom_headers = C.git_strarray{} - options.custom_headers.count = C.size_t(len(opts.Headers)) - options.custom_headers.strings = makeCStringsFromStrings(opts.Headers) + copts.pb_parallelism = C.uint(opts.PbParallelism) + copts.custom_headers = C.git_strarray{ + count: C.size_t(len(opts.Headers)), + strings: makeCStringsFromStrings(opts.Headers), + } + populateRemoteCallbacks(&copts.callbacks, &opts.RemoteCallbacks, errorTarget) + return copts +} - populateRemoteCallbacks(&options.callbacks, &opts.RemoteCallbacks) +func freePushOptions(copts *C.git_push_options) { + if copts == nil { + return + } + untrackCallbacksPayload(&copts.callbacks) + freeStrarray(&copts.custom_headers) } // Fetch performs a fetch operation. refspecs specifies which refspecs @@ -780,26 +855,29 @@ func (o *Remote) Fetch(refspecs []string, opts *FetchOptions, msg string) error defer C.free(unsafe.Pointer(cmsg)) } - crefspecs := C.git_strarray{} - crefspecs.count = C.size_t(len(refspecs)) - crefspecs.strings = makeCStringsFromStrings(refspecs) + var err error + crefspecs := C.git_strarray{ + count: C.size_t(len(refspecs)), + strings: makeCStringsFromStrings(refspecs), + } defer freeStrarray(&crefspecs) - coptions := (*C.git_fetch_options)(C.calloc(1, C.size_t(unsafe.Sizeof(C.git_fetch_options{})))) - defer C.free(unsafe.Pointer(coptions)) - - populateFetchOptions(coptions, opts) - defer untrackCalbacksPayload(&coptions.callbacks) - defer freeStrarray(&coptions.custom_headers) + coptions := populateFetchOptions(&C.git_fetch_options{}, opts, &err) + defer freeFetchOptions(coptions) runtime.LockOSThread() defer runtime.UnlockOSThread() ret := C.git_remote_fetch(o.ptr, &crefspecs, coptions, cmsg) runtime.KeepAlive(o) + + if ret == C.int(ErrorCodeUser) && err != nil { + return err + } if ret < 0 { return MakeGitError(ret) } + return nil } @@ -819,23 +897,27 @@ func (o *Remote) ConnectPush(callbacks *RemoteCallbacks, proxyOpts *ProxyOptions // // 'headers' are extra HTTP headers to use in this connection. func (o *Remote) Connect(direction ConnectDirection, callbacks *RemoteCallbacks, proxyOpts *ProxyOptions, headers []string) error { - var ccallbacks C.git_remote_callbacks - populateRemoteCallbacks(&ccallbacks, callbacks) + var err error + ccallbacks := populateRemoteCallbacks(&C.git_remote_callbacks{}, callbacks, &err) + defer untrackCallbacksPayload(ccallbacks) - var cproxy C.git_proxy_options - populateProxyOptions(&cproxy, proxyOpts) - defer freeProxyOptions(&cproxy) + cproxy := populateProxyOptions(&C.git_proxy_options{}, proxyOpts) + defer freeProxyOptions(cproxy) - cheaders := C.git_strarray{} - cheaders.count = C.size_t(len(headers)) - cheaders.strings = makeCStringsFromStrings(headers) + cheaders := C.git_strarray{ + count: C.size_t(len(headers)), + strings: makeCStringsFromStrings(headers), + } defer freeStrarray(&cheaders) runtime.LockOSThread() defer runtime.UnlockOSThread() - ret := C.git_remote_connect(o.ptr, C.git_direction(direction), &ccallbacks, &cproxy, &cheaders) + ret := C.git_remote_connect(o.ptr, C.git_direction(direction), ccallbacks, cproxy, &cheaders) runtime.KeepAlive(o) + if ret == C.int(ErrorCodeUser) && err != nil { + return err + } if ret != 0 { return MakeGitError(ret) } @@ -899,23 +981,24 @@ func (o *Remote) Ls(filterRefs ...string) ([]RemoteHead, error) { } func (o *Remote) Push(refspecs []string, opts *PushOptions) error { - crefspecs := C.git_strarray{} - crefspecs.count = C.size_t(len(refspecs)) - crefspecs.strings = makeCStringsFromStrings(refspecs) + crefspecs := C.git_strarray{ + count: C.size_t(len(refspecs)), + strings: makeCStringsFromStrings(refspecs), + } defer freeStrarray(&crefspecs) - coptions := (*C.git_push_options)(C.calloc(1, C.size_t(unsafe.Sizeof(C.git_push_options{})))) - defer C.free(unsafe.Pointer(coptions)) - - populatePushOptions(coptions, opts) - defer untrackCalbacksPayload(&coptions.callbacks) - defer freeStrarray(&coptions.custom_headers) + var err error + coptions := populatePushOptions(&C.git_push_options{}, opts, &err) + defer freePushOptions(coptions) runtime.LockOSThread() defer runtime.UnlockOSThread() ret := C.git_remote_push(o.ptr, &crefspecs, coptions) runtime.KeepAlive(o) + if ret == C.int(ErrorCodeUser) && err != nil { + return err + } if ret < 0 { return MakeGitError(ret) } @@ -927,14 +1010,18 @@ func (o *Remote) PruneRefs() bool { } func (o *Remote) Prune(callbacks *RemoteCallbacks) error { - var ccallbacks C.git_remote_callbacks - populateRemoteCallbacks(&ccallbacks, callbacks) + var err error + ccallbacks := populateRemoteCallbacks(&C.git_remote_callbacks{}, callbacks, &err) + defer untrackCallbacksPayload(ccallbacks) runtime.LockOSThread() defer runtime.UnlockOSThread() - ret := C.git_remote_prune(o.ptr, &ccallbacks) + ret := C.git_remote_prune(o.ptr, ccallbacks) runtime.KeepAlive(o) + if ret == C.int(ErrorCodeUser) && err != nil { + return err + } if ret < 0 { return MakeGitError(ret) } |
