summaryrefslogtreecommitdiff
path: root/lgc/download.go
diff options
context:
space:
mode:
authorWill Hawkins <[email protected]>2023-04-22 01:27:59 -0400
committerWill Hawkins <[email protected]>2023-04-22 01:27:59 -0400
commitd5ec3aba77624387711ffa90e6960e406e9790e6 (patch)
tree6c32da06f26e9abc0bc1821b2457a5f32d5cb098 /lgc/download.go
parentb2e528e07842488e573aefe783f1da755f818ffa (diff)
[Refactor] Move components into separate packages
A long-overdue change to split certain packages once smashed into the RPM package into their own package. The resulting code should make it easier for people to navigate the source code. In the process, fixed a bug where a self probe being started on a load-generating connection races with the establishment of the load-generating connection and causes a panic because the self probe is not establishing a connection on an already established connection.
Diffstat (limited to 'lgc/download.go')
-rw-r--r--lgc/download.go373
1 files changed, 373 insertions, 0 deletions
diff --git a/lgc/download.go b/lgc/download.go
new file mode 100644
index 0000000..71ed647
--- /dev/null
+++ b/lgc/download.go
@@ -0,0 +1,373 @@
+/*
+ * This file is part of Go Responsiveness.
+ *
+ * Go Responsiveness is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free Software Foundation,
+ * either version 2 of the License, or (at your option) any later version.
+ * Go Responsiveness is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
+ * PARTICULAR PURPOSE. See the GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with Go Responsiveness. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package lgc
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptrace"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/network-quality/goresponsiveness/debug"
+ "github.com/network-quality/goresponsiveness/stats"
+ "github.com/network-quality/goresponsiveness/traceable"
+ "github.com/network-quality/goresponsiveness/utilities"
+)
+
+// TODO: All 64-bit fields that are accessed atomically must
+// appear at the top of this struct.
+type LoadGeneratingConnectionDownload struct {
+ downloaded uint64
+ lastIntervalEnd int64
+ ConnectToAddr string
+ URL string
+ downloadStartTime time.Time
+ lastDownloaded uint64
+ client *http.Client
+ debug debug.DebugLevel
+ InsecureSkipVerify bool
+ KeyLogger io.Writer
+ clientId uint64
+ tracer *httptrace.ClientTrace
+ stats stats.TraceStats
+ status LgcStatus
+ statusLock *sync.Mutex
+ statusWaiter *sync.Cond
+}
+
+func NewLoadGeneratingConnectionDownload(url string, keyLogger io.Writer, connectToAddr string, insecureSkipVerify bool) LoadGeneratingConnectionDownload {
+ lgd := LoadGeneratingConnectionDownload{
+ URL: url,
+ KeyLogger: keyLogger,
+ ConnectToAddr: connectToAddr,
+ InsecureSkipVerify: insecureSkipVerify,
+ statusLock: &sync.Mutex{},
+ }
+ lgd.statusWaiter = sync.NewCond(lgd.statusLock)
+ return lgd
+}
+
+func (lgd *LoadGeneratingConnectionDownload) WaitUntilStarted(ctxt context.Context) bool {
+ conditional := func() bool { return lgd.status != LGC_STATUS_NOT_STARTED }
+ go utilities.ContextSignaler(ctxt, 500*time.Millisecond, &conditional, lgd.statusWaiter)
+ return utilities.WaitWithContext(ctxt, &conditional, lgd.statusLock, lgd.statusWaiter)
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetDnsStartTimeInfo(
+ now time.Time,
+ dnsStartInfo httptrace.DNSStartInfo,
+) {
+ lgd.stats.DnsStartTime = now
+ lgd.stats.DnsStart = dnsStartInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "DNS Start for %v: %v\n",
+ lgd.ClientId(),
+ dnsStartInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetDnsDoneTimeInfo(
+ now time.Time,
+ dnsDoneInfo httptrace.DNSDoneInfo,
+) {
+ lgd.stats.DnsDoneTime = now
+ lgd.stats.DnsDone = dnsDoneInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "DNS Done for %v: %v\n",
+ lgd.ClientId(),
+ lgd.stats.DnsDone,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetConnectStartTime(
+ now time.Time,
+) {
+ lgd.stats.ConnectStartTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "TCP Start for %v at %v\n",
+ lgd.ClientId(),
+ lgd.stats.ConnectStartTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetConnectDoneTimeError(
+ now time.Time,
+ err error,
+) {
+ lgd.stats.ConnectDoneTime = now
+ lgd.stats.ConnectDoneError = err
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "TCP Done for %v (with error %v) @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.ConnectDoneError,
+ lgd.stats.ConnectDoneTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetGetConnTime(now time.Time) {
+ lgd.stats.GetConnectionStartTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started getting connection for %v @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.GetConnectionStartTime,
+ )
+ }
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_RUNNING
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetGotConnTimeInfo(
+ now time.Time,
+ gotConnInfo httptrace.GotConnInfo,
+) {
+ if gotConnInfo.Reused {
+ fmt.Printf("Unexpectedly reusing a connection!\n")
+ panic(!gotConnInfo.Reused)
+ }
+ lgd.stats.GetConnectionDoneTime = now
+ lgd.stats.ConnInfo = gotConnInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Got connection for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.GetConnectionDoneTime,
+ lgd.stats.ConnInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeStartTime(
+ now time.Time,
+) {
+ lgd.stats.TLSStartTime = utilities.Some(now)
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started TLS Handshake for %v @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.TLSStartTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeDoneTimeState(
+ now time.Time,
+ connectionState tls.ConnectionState,
+) {
+ lgd.stats.TLSDoneTime = utilities.Some(now)
+ lgd.stats.TLSConnInfo = connectionState
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Completed TLS handshake for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.TLSDoneTime,
+ lgd.stats.TLSConnInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetHttpWroteRequestTimeInfo(
+ now time.Time,
+ info httptrace.WroteRequestInfo,
+) {
+ lgd.stats.HttpWroteRequestTime = now
+ lgd.stats.HttpInfo = info
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "(lgd) Http finished writing request for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.HttpWroteRequestTime,
+ lgd.stats.HttpInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetHttpResponseReadyTime(
+ now time.Time,
+) {
+ lgd.stats.HttpResponseReadyTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Got the first byte of HTTP response headers for %v at %v\n",
+ lgd.ClientId(),
+ lgd.stats.HttpResponseReadyTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) ClientId() uint64 {
+ return lgd.clientId
+}
+
+func (lgd *LoadGeneratingConnectionDownload) TransferredInInterval() (uint64, time.Duration) {
+ transferred := atomic.SwapUint64(&lgd.downloaded, 0)
+ newIntervalEnd := (time.Now().Sub(lgd.downloadStartTime)).Nanoseconds()
+ previousIntervalEnd := atomic.SwapInt64(&lgd.lastIntervalEnd, newIntervalEnd)
+ intervalLength := time.Duration(newIntervalEnd - previousIntervalEnd)
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf("download: Transferred: %v bytes in %v.\n", transferred, intervalLength)
+ }
+ return transferred, intervalLength
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Client() *http.Client {
+ return lgd.client
+}
+
+type countingReader struct {
+ n *uint64
+ ctx context.Context
+ readable io.Reader
+}
+
+func (cr *countingReader) Read(p []byte) (n int, err error) {
+ if cr.ctx.Err() != nil {
+ return 0, io.EOF
+ }
+
+ n, err = cr.readable.Read(p)
+ atomic.AddUint64(cr.n, uint64(n))
+ return
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Start(
+ parentCtx context.Context,
+ debugLevel debug.DebugLevel,
+) bool {
+ lgd.downloaded = 0
+ lgd.debug = debugLevel
+ lgd.clientId = utilities.GenerateUniqueId()
+
+ transport := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: lgd.InsecureSkipVerify,
+ },
+ }
+
+ if !utilities.IsInterfaceNil(lgd.KeyLogger) {
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Using an SSL Key Logger for this load-generating download.\n",
+ )
+ }
+
+ // The presence of a custom TLSClientConfig in a *generic* `transport`
+ // means that go will default to HTTP/1.1 and cowardly avoid HTTP/2:
+ // https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L278
+ // Also, it would appear that the API's choice of HTTP vs HTTP2 can
+ // depend on whether the url contains
+ // https:// or http://:
+ // https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L74
+ transport.TLSClientConfig.KeyLogWriter = lgd.KeyLogger
+ }
+ transport.TLSClientConfig.InsecureSkipVerify = lgd.InsecureSkipVerify
+
+ utilities.OverrideHostTransport(transport, lgd.ConnectToAddr)
+
+ lgd.client = &http.Client{Transport: transport}
+ lgd.tracer = traceable.GenerateHttpTimingTracer(lgd, lgd.debug)
+
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started a load-generating download (id: %v).\n",
+ lgd.clientId,
+ )
+ }
+
+ go lgd.doDownload(parentCtx)
+ return true
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Status() LgcStatus {
+ return lgd.status
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Stats() *stats.TraceStats {
+ return &lgd.stats
+}
+
+func (lgd *LoadGeneratingConnectionDownload) doDownload(ctx context.Context) error {
+ var request *http.Request = nil
+ var get *http.Response = nil
+ var err error = nil
+
+ if request, err = http.NewRequestWithContext(
+ httptrace.WithClientTrace(ctx, lgd.tracer),
+ "GET",
+ lgd.URL,
+ nil,
+ ); err != nil {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ return err
+ }
+
+ // Used to disable compression
+ request.Header.Set("Accept-Encoding", "identity")
+ request.Header.Set("User-Agent", utilities.UserAgent())
+
+ lgd.downloadStartTime = time.Now()
+ lgd.lastIntervalEnd = 0
+
+ if get, err = lgd.client.Do(request); err != nil {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ return err
+ }
+
+ // Header.Get returns "" when not set
+ if get.Header.Get("Content-Encoding") != "" {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ fmt.Printf("Content-Encoding header was set (compression not allowed)")
+ return fmt.Errorf("Content-Encoding header was set (compression not allowed)")
+ }
+ cr := &countingReader{n: &lgd.downloaded, ctx: ctx, readable: get.Body}
+ _, _ = io.Copy(io.Discard, cr)
+
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_DONE
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+
+ get.Body.Close()
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf("Ending a load-generating download.\n")
+ }
+
+ return nil
+}