summaryrefslogtreecommitdiff
path: root/lgc/download.go
diff options
context:
space:
mode:
Diffstat (limited to 'lgc/download.go')
-rw-r--r--lgc/download.go373
1 files changed, 373 insertions, 0 deletions
diff --git a/lgc/download.go b/lgc/download.go
new file mode 100644
index 0000000..71ed647
--- /dev/null
+++ b/lgc/download.go
@@ -0,0 +1,373 @@
+/*
+ * This file is part of Go Responsiveness.
+ *
+ * Go Responsiveness is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free Software Foundation,
+ * either version 2 of the License, or (at your option) any later version.
+ * Go Responsiveness is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
+ * PARTICULAR PURPOSE. See the GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with Go Responsiveness. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package lgc
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptrace"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/network-quality/goresponsiveness/debug"
+ "github.com/network-quality/goresponsiveness/stats"
+ "github.com/network-quality/goresponsiveness/traceable"
+ "github.com/network-quality/goresponsiveness/utilities"
+)
+
+// TODO: All 64-bit fields that are accessed atomically must
+// appear at the top of this struct.
+type LoadGeneratingConnectionDownload struct {
+ downloaded uint64
+ lastIntervalEnd int64
+ ConnectToAddr string
+ URL string
+ downloadStartTime time.Time
+ lastDownloaded uint64
+ client *http.Client
+ debug debug.DebugLevel
+ InsecureSkipVerify bool
+ KeyLogger io.Writer
+ clientId uint64
+ tracer *httptrace.ClientTrace
+ stats stats.TraceStats
+ status LgcStatus
+ statusLock *sync.Mutex
+ statusWaiter *sync.Cond
+}
+
+func NewLoadGeneratingConnectionDownload(url string, keyLogger io.Writer, connectToAddr string, insecureSkipVerify bool) LoadGeneratingConnectionDownload {
+ lgd := LoadGeneratingConnectionDownload{
+ URL: url,
+ KeyLogger: keyLogger,
+ ConnectToAddr: connectToAddr,
+ InsecureSkipVerify: insecureSkipVerify,
+ statusLock: &sync.Mutex{},
+ }
+ lgd.statusWaiter = sync.NewCond(lgd.statusLock)
+ return lgd
+}
+
+func (lgd *LoadGeneratingConnectionDownload) WaitUntilStarted(ctxt context.Context) bool {
+ conditional := func() bool { return lgd.status != LGC_STATUS_NOT_STARTED }
+ go utilities.ContextSignaler(ctxt, 500*time.Millisecond, &conditional, lgd.statusWaiter)
+ return utilities.WaitWithContext(ctxt, &conditional, lgd.statusLock, lgd.statusWaiter)
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetDnsStartTimeInfo(
+ now time.Time,
+ dnsStartInfo httptrace.DNSStartInfo,
+) {
+ lgd.stats.DnsStartTime = now
+ lgd.stats.DnsStart = dnsStartInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "DNS Start for %v: %v\n",
+ lgd.ClientId(),
+ dnsStartInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetDnsDoneTimeInfo(
+ now time.Time,
+ dnsDoneInfo httptrace.DNSDoneInfo,
+) {
+ lgd.stats.DnsDoneTime = now
+ lgd.stats.DnsDone = dnsDoneInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "DNS Done for %v: %v\n",
+ lgd.ClientId(),
+ lgd.stats.DnsDone,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetConnectStartTime(
+ now time.Time,
+) {
+ lgd.stats.ConnectStartTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "TCP Start for %v at %v\n",
+ lgd.ClientId(),
+ lgd.stats.ConnectStartTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetConnectDoneTimeError(
+ now time.Time,
+ err error,
+) {
+ lgd.stats.ConnectDoneTime = now
+ lgd.stats.ConnectDoneError = err
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "TCP Done for %v (with error %v) @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.ConnectDoneError,
+ lgd.stats.ConnectDoneTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetGetConnTime(now time.Time) {
+ lgd.stats.GetConnectionStartTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started getting connection for %v @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.GetConnectionStartTime,
+ )
+ }
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_RUNNING
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetGotConnTimeInfo(
+ now time.Time,
+ gotConnInfo httptrace.GotConnInfo,
+) {
+ if gotConnInfo.Reused {
+ fmt.Printf("Unexpectedly reusing a connection!\n")
+ panic(!gotConnInfo.Reused)
+ }
+ lgd.stats.GetConnectionDoneTime = now
+ lgd.stats.ConnInfo = gotConnInfo
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Got connection for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.GetConnectionDoneTime,
+ lgd.stats.ConnInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeStartTime(
+ now time.Time,
+) {
+ lgd.stats.TLSStartTime = utilities.Some(now)
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started TLS Handshake for %v @ %v\n",
+ lgd.ClientId(),
+ lgd.stats.TLSStartTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeDoneTimeState(
+ now time.Time,
+ connectionState tls.ConnectionState,
+) {
+ lgd.stats.TLSDoneTime = utilities.Some(now)
+ lgd.stats.TLSConnInfo = connectionState
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Completed TLS handshake for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.TLSDoneTime,
+ lgd.stats.TLSConnInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetHttpWroteRequestTimeInfo(
+ now time.Time,
+ info httptrace.WroteRequestInfo,
+) {
+ lgd.stats.HttpWroteRequestTime = now
+ lgd.stats.HttpInfo = info
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "(lgd) Http finished writing request for %v at %v with info %v\n",
+ lgd.ClientId(),
+ lgd.stats.HttpWroteRequestTime,
+ lgd.stats.HttpInfo,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) SetHttpResponseReadyTime(
+ now time.Time,
+) {
+ lgd.stats.HttpResponseReadyTime = now
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Got the first byte of HTTP response headers for %v at %v\n",
+ lgd.ClientId(),
+ lgd.stats.HttpResponseReadyTime,
+ )
+ }
+}
+
+func (lgd *LoadGeneratingConnectionDownload) ClientId() uint64 {
+ return lgd.clientId
+}
+
+func (lgd *LoadGeneratingConnectionDownload) TransferredInInterval() (uint64, time.Duration) {
+ transferred := atomic.SwapUint64(&lgd.downloaded, 0)
+ newIntervalEnd := (time.Now().Sub(lgd.downloadStartTime)).Nanoseconds()
+ previousIntervalEnd := atomic.SwapInt64(&lgd.lastIntervalEnd, newIntervalEnd)
+ intervalLength := time.Duration(newIntervalEnd - previousIntervalEnd)
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf("download: Transferred: %v bytes in %v.\n", transferred, intervalLength)
+ }
+ return transferred, intervalLength
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Client() *http.Client {
+ return lgd.client
+}
+
+type countingReader struct {
+ n *uint64
+ ctx context.Context
+ readable io.Reader
+}
+
+func (cr *countingReader) Read(p []byte) (n int, err error) {
+ if cr.ctx.Err() != nil {
+ return 0, io.EOF
+ }
+
+ n, err = cr.readable.Read(p)
+ atomic.AddUint64(cr.n, uint64(n))
+ return
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Start(
+ parentCtx context.Context,
+ debugLevel debug.DebugLevel,
+) bool {
+ lgd.downloaded = 0
+ lgd.debug = debugLevel
+ lgd.clientId = utilities.GenerateUniqueId()
+
+ transport := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: lgd.InsecureSkipVerify,
+ },
+ }
+
+ if !utilities.IsInterfaceNil(lgd.KeyLogger) {
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Using an SSL Key Logger for this load-generating download.\n",
+ )
+ }
+
+ // The presence of a custom TLSClientConfig in a *generic* `transport`
+ // means that go will default to HTTP/1.1 and cowardly avoid HTTP/2:
+ // https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L278
+ // Also, it would appear that the API's choice of HTTP vs HTTP2 can
+ // depend on whether the url contains
+ // https:// or http://:
+ // https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L74
+ transport.TLSClientConfig.KeyLogWriter = lgd.KeyLogger
+ }
+ transport.TLSClientConfig.InsecureSkipVerify = lgd.InsecureSkipVerify
+
+ utilities.OverrideHostTransport(transport, lgd.ConnectToAddr)
+
+ lgd.client = &http.Client{Transport: transport}
+ lgd.tracer = traceable.GenerateHttpTimingTracer(lgd, lgd.debug)
+
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf(
+ "Started a load-generating download (id: %v).\n",
+ lgd.clientId,
+ )
+ }
+
+ go lgd.doDownload(parentCtx)
+ return true
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Status() LgcStatus {
+ return lgd.status
+}
+
+func (lgd *LoadGeneratingConnectionDownload) Stats() *stats.TraceStats {
+ return &lgd.stats
+}
+
+func (lgd *LoadGeneratingConnectionDownload) doDownload(ctx context.Context) error {
+ var request *http.Request = nil
+ var get *http.Response = nil
+ var err error = nil
+
+ if request, err = http.NewRequestWithContext(
+ httptrace.WithClientTrace(ctx, lgd.tracer),
+ "GET",
+ lgd.URL,
+ nil,
+ ); err != nil {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ return err
+ }
+
+ // Used to disable compression
+ request.Header.Set("Accept-Encoding", "identity")
+ request.Header.Set("User-Agent", utilities.UserAgent())
+
+ lgd.downloadStartTime = time.Now()
+ lgd.lastIntervalEnd = 0
+
+ if get, err = lgd.client.Do(request); err != nil {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ return err
+ }
+
+ // Header.Get returns "" when not set
+ if get.Header.Get("Content-Encoding") != "" {
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_ERROR
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+ fmt.Printf("Content-Encoding header was set (compression not allowed)")
+ return fmt.Errorf("Content-Encoding header was set (compression not allowed)")
+ }
+ cr := &countingReader{n: &lgd.downloaded, ctx: ctx, readable: get.Body}
+ _, _ = io.Copy(io.Discard, cr)
+
+ lgd.statusLock.Lock()
+ lgd.status = LGC_STATUS_DONE
+ lgd.statusWaiter.Broadcast()
+ lgd.statusLock.Unlock()
+
+ get.Body.Close()
+ if debug.IsDebug(lgd.debug) {
+ fmt.Printf("Ending a load-generating download.\n")
+ }
+
+ return nil
+}