Sfoglia il codice sorgente

support https with certificate (#423)

Support https with server certificate.
Margaret Nolan 6 anni fa
parent
commit
b44fece901

+ 20
- 3
doc/README.md Vedi File

@@ -256,9 +256,10 @@ total of the absolute value of all samples when aggregated at the address level.
256 256
 
257 257
 # Fetching profiles
258 258
 
259
-pprof can read profiles from a file or directly from a URL over http. Its native
260
-format is a gzipped profile.proto file, but it can also accept some legacy
261
-formats generated by [gperftools](https://github.com/gperftools/gperftools).
259
+pprof can read profiles from a file or directly from a URL over http or https.
260
+Its native format is a gzipped profile.proto file, but it can
261
+also accept some legacy formats generated by 
262
+[gperftools](https://github.com/gperftools/gperftools).
262 263
 
263 264
 When fetching from a URL handler, pprof accepts options to indicate how much to
264 265
 wait for the profile.
@@ -270,6 +271,22 @@ wait for the profile.
270 271
   profile over http. If not specified, pprof will use heuristics to determine a
271 272
   reasonable timeout.
272 273
 
274
+pprof also accepts options which allow a user to specify TLS certificates to
275
+use when fetching or symbolizing a profile from a protected endpoint. For more
276
+information about generating these certificates, see
277
+https://docs.docker.com/engine/security/https/.
278
+
279
+* **-tls\_cert= _/path/to/cert_:** File containing the TLS client certificate
280
+  to be used when fetching and symbolizing profiles.
281
+* **-tls\_key= _/path/to/key_:** File containing the TLS private key to be used
282
+  when fetching and symbolizing profiles.
283
+* **-tls\_ca= _/path/to/ca_:** File containing the certificate authority to be
284
+  used when fetching and symbolizing profiles.
285
+
286
+pprof also supports skipping verification of the server's certificate chain and
287
+host name when collecting or symbolizing a profile. To skip this verification, 
288
+use "https+insecure" in place of "https" in the URL.
289
+
273 290
 If multiple profiles are specified, pprof will fetch them all and merge
274 291
 them. This is useful to combine profiles from multiple processes of a
275 292
 distributed job. The profiles may be from different programs but must be

+ 25
- 18
driver/driver.go Vedi File

@@ -17,6 +17,7 @@ package driver
17 17
 
18 18
 import (
19 19
 	"io"
20
+	"net/http"
20 21
 	"regexp"
21 22
 	"time"
22 23
 
@@ -48,13 +49,14 @@ func (o *Options) internalOptions() *plugin.Options {
48 49
 		}
49 50
 	}
50 51
 	return &plugin.Options{
51
-		Writer:     o.Writer,
52
-		Flagset:    o.Flagset,
53
-		Fetch:      o.Fetch,
54
-		Sym:        sym,
55
-		Obj:        obj,
56
-		UI:         o.UI,
57
-		HTTPServer: httpServer,
52
+		Writer:        o.Writer,
53
+		Flagset:       o.Flagset,
54
+		Fetch:         o.Fetch,
55
+		Sym:           sym,
56
+		Obj:           obj,
57
+		UI:            o.UI,
58
+		HTTPServer:    httpServer,
59
+		HTTPTransport: o.HTTPTransport,
58 60
 	}
59 61
 }
60 62
 
@@ -64,13 +66,14 @@ type HTTPServerArgs plugin.HTTPServerArgs
64 66
 
65 67
 // Options groups all the optional plugins into pprof.
66 68
 type Options struct {
67
-	Writer     Writer
68
-	Flagset    FlagSet
69
-	Fetch      Fetcher
70
-	Sym        Symbolizer
71
-	Obj        ObjTool
72
-	UI         UI
73
-	HTTPServer func(*HTTPServerArgs) error
69
+	Writer        Writer
70
+	Flagset       FlagSet
71
+	Fetch         Fetcher
72
+	Sym           Symbolizer
73
+	Obj           ObjTool
74
+	UI            UI
75
+	HTTPServer    func(*HTTPServerArgs) error
76
+	HTTPTransport http.RoundTripper
74 77
 }
75 78
 
76 79
 // Writer provides a mechanism to write data under a certain name,
@@ -100,12 +103,16 @@ type FlagSet interface {
100 103
 	// single flag
101 104
 	StringList(name string, def string, usage string) *[]*string
102 105
 
103
-	// ExtraUsage returns any additional text that should be
104
-	// printed after the standard usage message.
105
-	// The typical use of ExtraUsage is to show any custom flags
106
-	// defined by the specific pprof plugins being used.
106
+	// ExtraUsage returns any additional text that should be printed after the
107
+	// standard usage message. The extra usage message returned includes all text
108
+	// added with AddExtraUsage().
109
+	// The typical use of ExtraUsage is to show any custom flags defined by the
110
+	// specific pprof plugins being used.
107 111
 	ExtraUsage() string
108 112
 
113
+	// AddExtraUsage appends additional text to the end of the extra usage message.
114
+	AddExtraUsage(eu string)
115
+
109 116
 	// Parse initializes the flags with their values for this run
110 117
 	// and returns the non-flag command line arguments.
111 118
 	// If an unknown flag is encountered or there are no arguments,

+ 8
- 4
internal/driver/driver_test.go Vedi File

@@ -130,8 +130,9 @@ func TestParse(t *testing.T) {
130 130
 			}
131 131
 
132 132
 			// First pprof invocation to save the profile into a profile.proto.
133
-			o1 := setDefaults(nil)
134
-			o1.Flagset = f
133
+			// Pass in flag set hen setting defaults, because otherwise default
134
+			// transport will try to add flags to the default flag set.
135
+			o1 := setDefaults(&plugin.Options{Flagset: f})
135 136
 			o1.Fetch = testFetcher{}
136 137
 			o1.Sym = testSymbolizer{}
137 138
 			o1.UI = testUI
@@ -167,8 +168,9 @@ func TestParse(t *testing.T) {
167 168
 
168 169
 			// Second pprof invocation to read the profile from profile.proto
169 170
 			// and generate a report.
170
-			o2 := setDefaults(nil)
171
-			o2.Flagset = f
171
+			// Pass in flag set hen setting defaults, because otherwise default
172
+			// transport will try to add flags to the default flag set.
173
+			o2 := setDefaults(&plugin.Options{Flagset: f})
172 174
 			o2.Sym = testSymbolizeDemangler{}
173 175
 			o2.Obj = new(mockObjTool)
174 176
 			o2.UI = testUI
@@ -297,6 +299,8 @@ type testFlags struct {
297 299
 
298 300
 func (testFlags) ExtraUsage() string { return "" }
299 301
 
302
+func (testFlags) AddExtraUsage(eu string) {}
303
+
300 304
 func (f testFlags) Bool(s string, d bool, c string) *bool {
301 305
 	if b, ok := f.bools[s]; ok {
302 306
 		return &b

+ 18
- 42
internal/driver/fetch.go Vedi File

@@ -16,7 +16,6 @@ package driver
16 16
 
17 17
 import (
18 18
 	"bytes"
19
-	"crypto/tls"
20 19
 	"fmt"
21 20
 	"io"
22 21
 	"io/ioutil"
@@ -57,7 +56,7 @@ func fetchProfiles(s *source, o *plugin.Options) (*profile.Profile, error) {
57 56
 		})
58 57
 	}
59 58
 
60
-	p, pbase, m, mbase, save, err := grabSourcesAndBases(sources, bases, o.Fetch, o.Obj, o.UI)
59
+	p, pbase, m, mbase, save, err := grabSourcesAndBases(sources, bases, o.Fetch, o.Obj, o.UI, o.HTTPTransport)
61 60
 	if err != nil {
62 61
 		return nil, err
63 62
 	}
@@ -123,7 +122,7 @@ func fetchProfiles(s *source, o *plugin.Options) (*profile.Profile, error) {
123 122
 	return p, nil
124 123
 }
125 124
 
126
-func grabSourcesAndBases(sources, bases []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI) (*profile.Profile, *profile.Profile, plugin.MappingSources, plugin.MappingSources, bool, error) {
125
+func grabSourcesAndBases(sources, bases []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI, tr http.RoundTripper) (*profile.Profile, *profile.Profile, plugin.MappingSources, plugin.MappingSources, bool, error) {
127 126
 	wg := sync.WaitGroup{}
128 127
 	wg.Add(2)
129 128
 	var psrc, pbase *profile.Profile
@@ -133,11 +132,11 @@ func grabSourcesAndBases(sources, bases []profileSource, fetch plugin.Fetcher, o
133 132
 	var countsrc, countbase int
134 133
 	go func() {
135 134
 		defer wg.Done()
136
-		psrc, msrc, savesrc, countsrc, errsrc = chunkedGrab(sources, fetch, obj, ui)
135
+		psrc, msrc, savesrc, countsrc, errsrc = chunkedGrab(sources, fetch, obj, ui, tr)
137 136
 	}()
138 137
 	go func() {
139 138
 		defer wg.Done()
140
-		pbase, mbase, savebase, countbase, errbase = chunkedGrab(bases, fetch, obj, ui)
139
+		pbase, mbase, savebase, countbase, errbase = chunkedGrab(bases, fetch, obj, ui, tr)
141 140
 	}()
142 141
 	wg.Wait()
143 142
 	save := savesrc || savebase
@@ -167,7 +166,7 @@ func grabSourcesAndBases(sources, bases []profileSource, fetch plugin.Fetcher, o
167 166
 // chunkedGrab fetches the profiles described in source and merges them into
168 167
 // a single profile. It fetches a chunk of profiles concurrently, with a maximum
169 168
 // chunk size to limit its memory usage.
170
-func chunkedGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI) (*profile.Profile, plugin.MappingSources, bool, int, error) {
169
+func chunkedGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI, tr http.RoundTripper) (*profile.Profile, plugin.MappingSources, bool, int, error) {
171 170
 	const chunkSize = 64
172 171
 
173 172
 	var p *profile.Profile
@@ -180,7 +179,7 @@ func chunkedGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTo
180 179
 		if end > len(sources) {
181 180
 			end = len(sources)
182 181
 		}
183
-		chunkP, chunkMsrc, chunkSave, chunkCount, chunkErr := concurrentGrab(sources[start:end], fetch, obj, ui)
182
+		chunkP, chunkMsrc, chunkSave, chunkCount, chunkErr := concurrentGrab(sources[start:end], fetch, obj, ui, tr)
184 183
 		switch {
185 184
 		case chunkErr != nil:
186 185
 			return nil, nil, false, 0, chunkErr
@@ -204,13 +203,13 @@ func chunkedGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTo
204 203
 }
205 204
 
206 205
 // concurrentGrab fetches multiple profiles concurrently
207
-func concurrentGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI) (*profile.Profile, plugin.MappingSources, bool, int, error) {
206
+func concurrentGrab(sources []profileSource, fetch plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI, tr http.RoundTripper) (*profile.Profile, plugin.MappingSources, bool, int, error) {
208 207
 	wg := sync.WaitGroup{}
209 208
 	wg.Add(len(sources))
210 209
 	for i := range sources {
211 210
 		go func(s *profileSource) {
212 211
 			defer wg.Done()
213
-			s.p, s.msrc, s.remote, s.err = grabProfile(s.source, s.addr, fetch, obj, ui)
212
+			s.p, s.msrc, s.remote, s.err = grabProfile(s.source, s.addr, fetch, obj, ui, tr)
214 213
 		}(&sources[i])
215 214
 	}
216 215
 	wg.Wait()
@@ -310,7 +309,7 @@ const testSourceAddress = "pproftest.local"
310 309
 // grabProfile fetches a profile. Returns the profile, sources for the
311 310
 // profile mappings, a bool indicating if the profile was fetched
312 311
 // remotely, and an error.
313
-func grabProfile(s *source, source string, fetcher plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI) (p *profile.Profile, msrc plugin.MappingSources, remote bool, err error) {
312
+func grabProfile(s *source, source string, fetcher plugin.Fetcher, obj plugin.ObjTool, ui plugin.UI, tr http.RoundTripper) (p *profile.Profile, msrc plugin.MappingSources, remote bool, err error) {
314 313
 	var src string
315 314
 	duration, timeout := time.Duration(s.Seconds)*time.Second, time.Duration(s.Timeout)*time.Second
316 315
 	if fetcher != nil {
@@ -321,7 +320,7 @@ func grabProfile(s *source, source string, fetcher plugin.Fetcher, obj plugin.Ob
321 320
 	}
322 321
 	if err != nil || p == nil {
323 322
 		// Fetch the profile over HTTP or from a file.
324
-		p, src, err = fetch(source, duration, timeout, ui)
323
+		p, src, err = fetch(source, duration, timeout, ui, tr)
325 324
 		if err != nil {
326 325
 			return
327 326
 		}
@@ -461,7 +460,7 @@ mapping:
461 460
 // fetch fetches a profile from source, within the timeout specified,
462 461
 // producing messages through the ui. It returns the profile and the
463 462
 // url of the actual source of the profile for remote profiles.
464
-func fetch(source string, duration, timeout time.Duration, ui plugin.UI) (p *profile.Profile, src string, err error) {
463
+func fetch(source string, duration, timeout time.Duration, ui plugin.UI, tr http.RoundTripper) (p *profile.Profile, src string, err error) {
465 464
 	var f io.ReadCloser
466 465
 
467 466
 	if sourceURL, timeout := adjustURL(source, duration, timeout); sourceURL != "" {
@@ -469,7 +468,7 @@ func fetch(source string, duration, timeout time.Duration, ui plugin.UI) (p *pro
469 468
 		if duration > 0 {
470 469
 			ui.Print(fmt.Sprintf("Please wait... (%v)", duration))
471 470
 		}
472
-		f, err = fetchURL(sourceURL, timeout)
471
+		f, err = fetchURL(sourceURL, timeout, tr)
473 472
 		src = sourceURL
474 473
 	} else if isPerfFile(source) {
475 474
 		f, err = convertPerfData(source, ui)
@@ -484,8 +483,12 @@ func fetch(source string, duration, timeout time.Duration, ui plugin.UI) (p *pro
484 483
 }
485 484
 
486 485
 // fetchURL fetches a profile from a URL using HTTP.
487
-func fetchURL(source string, timeout time.Duration) (io.ReadCloser, error) {
488
-	resp, err := httpGet(source, timeout)
486
+func fetchURL(source string, timeout time.Duration, tr http.RoundTripper) (io.ReadCloser, error) {
487
+	client := &http.Client{
488
+		Transport: tr,
489
+		Timeout:   timeout + 5*time.Second,
490
+	}
491
+	resp, err := client.Get(source)
489 492
 	if err != nil {
490 493
 		return nil, fmt.Errorf("http fetch: %v", err)
491 494
 	}
@@ -582,30 +585,3 @@ func adjustURL(source string, duration, timeout time.Duration) (string, time.Dur
582 585
 	u.RawQuery = values.Encode()
583 586
 	return u.String(), timeout
584 587
 }
585
-
586
-// httpGet is a wrapper around http.Get; it is defined as a variable
587
-// so it can be redefined during for testing.
588
-var httpGet = func(source string, timeout time.Duration) (*http.Response, error) {
589
-	url, err := url.Parse(source)
590
-	if err != nil {
591
-		return nil, err
592
-	}
593
-
594
-	var tlsConfig *tls.Config
595
-	if url.Scheme == "https+insecure" {
596
-		tlsConfig = &tls.Config{
597
-			InsecureSkipVerify: true,
598
-		}
599
-		url.Scheme = "https"
600
-		source = url.String()
601
-	}
602
-
603
-	client := &http.Client{
604
-		Transport: &http.Transport{
605
-			Proxy:                 http.ProxyFromEnvironment,
606
-			TLSClientConfig:       tlsConfig,
607
-			ResponseHeaderTimeout: timeout + 5*time.Second,
608
-		},
609
-	}
610
-	return client.Get(source)
611
-}

+ 139
- 26
internal/driver/fetch_test.go Vedi File

@@ -24,8 +24,8 @@ import (
24 24
 	"fmt"
25 25
 	"io/ioutil"
26 26
 	"math/big"
27
+	"net"
27 28
 	"net/http"
28
-	"net/url"
29 29
 	"os"
30 30
 	"path/filepath"
31 31
 	"reflect"
@@ -39,6 +39,7 @@ import (
39 39
 	"github.com/google/pprof/internal/plugin"
40 40
 	"github.com/google/pprof/internal/proftest"
41 41
 	"github.com/google/pprof/internal/symbolizer"
42
+	"github.com/google/pprof/internal/transport"
42 43
 	"github.com/google/pprof/profile"
43 44
 )
44 45
 
@@ -173,12 +174,6 @@ func (testFile) Close() error                                                 {
173 174
 
174 175
 func TestFetch(t *testing.T) {
175 176
 	const path = "testdata/"
176
-
177
-	// Intercept http.Get calls from HTTPFetcher.
178
-	savedHTTPGet := httpGet
179
-	defer func() { httpGet = savedHTTPGet }()
180
-	httpGet = stubHTTPGet
181
-
182 177
 	type testcase struct {
183 178
 		source, execName string
184 179
 	}
@@ -188,7 +183,7 @@ func TestFetch(t *testing.T) {
188 183
 		{path + "go.nomappings.crash", "/bin/gotest.exe"},
189 184
 		{"http://localhost/profile?file=cppbench.cpu", ""},
190 185
 	} {
191
-		p, _, _, err := grabProfile(&source{ExecName: tc.execName}, tc.source, nil, testObj{}, &proftest.TestUI{T: t})
186
+		p, _, _, err := grabProfile(&source{ExecName: tc.execName}, tc.source, nil, testObj{}, &proftest.TestUI{T: t}, &httpTransport{})
192 187
 		if err != nil {
193 188
 			t.Fatalf("%s: %s", tc.source, err)
194 189
 		}
@@ -449,8 +444,9 @@ func TestFetchWithBase(t *testing.T) {
449 444
 			f.args = tc.sources
450 445
 
451 446
 			o := setDefaults(&plugin.Options{
452
-				UI:      &proftest.TestUI{T: t, AllowRx: "Local symbolization failed|Some binary filenames not available"},
453
-				Flagset: f,
447
+				UI:            &proftest.TestUI{T: t, AllowRx: "Local symbolization failed|Some binary filenames not available"},
448
+				Flagset:       f,
449
+				HTTPTransport: transport.New(nil),
454 450
 			})
455 451
 			src, _, err := parseFlags(o)
456 452
 
@@ -503,19 +499,14 @@ func mappingSources(key, source string, start uint64) plugin.MappingSources {
503 499
 	}
504 500
 }
505 501
 
506
-// stubHTTPGet intercepts a call to http.Get and rewrites it to use
507
-// "file://" to get the profile directly from a file.
508
-func stubHTTPGet(source string, _ time.Duration) (*http.Response, error) {
509
-	url, err := url.Parse(source)
510
-	if err != nil {
511
-		return nil, err
512
-	}
502
+type httpTransport struct{}
513 503
 
514
-	values := url.Query()
504
+func (tr *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
505
+	values := req.URL.Query()
515 506
 	file := values.Get("file")
516 507
 
517 508
 	if file == "" {
518
-		return nil, fmt.Errorf("want .../file?profile, got %s", source)
509
+		return nil, fmt.Errorf("want .../file?profile, got %s", req.URL.String())
519 510
 	}
520 511
 
521 512
 	t := &http.Transport{}
@@ -532,7 +523,7 @@ func closedError() string {
532 523
 	return "use of closed"
533 524
 }
534 525
 
535
-func TestHttpsInsecure(t *testing.T) {
526
+func TestHTTPSInsecure(t *testing.T) {
536 527
 	if runtime.GOOS == "nacl" || runtime.GOOS == "js" {
537 528
 		t.Skip("test assumes tcp available")
538 529
 	}
@@ -553,7 +544,8 @@ func TestHttpsInsecure(t *testing.T) {
553 544
 	pprofVariables = baseVars.makeCopy()
554 545
 	defer func() { pprofVariables = baseVars }()
555 546
 
556
-	tlsConfig := &tls.Config{Certificates: []tls.Certificate{selfSignedCert(t)}}
547
+	tlsCert, _, _ := selfSignedCert(t, "")
548
+	tlsConfig := &tls.Config{Certificates: []tls.Certificate{tlsCert}}
557 549
 
558 550
 	l, err := tls.Listen("tcp", "localhost:0", tlsConfig)
559 551
 	if err != nil {
@@ -586,8 +578,9 @@ func TestHttpsInsecure(t *testing.T) {
586 578
 		Symbolize: "remote",
587 579
 	}
588 580
 	o := &plugin.Options{
589
-		Obj: &binutils.Binutils{},
590
-		UI:  &proftest.TestUI{T: t, AllowRx: "Saved profile in"},
581
+		Obj:           &binutils.Binutils{},
582
+		UI:            &proftest.TestUI{T: t, AllowRx: "Saved profile in"},
583
+		HTTPTransport: transport.New(nil),
591 584
 	}
592 585
 	o.Sym = &symbolizer.Symbolizer{Obj: o.Obj, UI: o.UI}
593 586
 	p, err := fetchProfiles(s, o)
@@ -600,7 +593,122 @@ func TestHttpsInsecure(t *testing.T) {
600 593
 	if len(p.Function) == 0 {
601 594
 		t.Fatalf("fetchProfiles(%s) got non-symbolized profile: len(p.Function)==0", address)
602 595
 	}
603
-	if err := checkProfileHasFunction(p, "TestHttpsInsecure"); err != nil {
596
+	if err := checkProfileHasFunction(p, "TestHTTPSInsecure"); err != nil {
597
+		t.Fatalf("fetchProfiles(%s) %v", address, err)
598
+	}
599
+}
600
+
601
+func TestHTTPSWithServerCertFetch(t *testing.T) {
602
+	if runtime.GOOS == "nacl" || runtime.GOOS == "js" {
603
+		t.Skip("test assumes tcp available")
604
+	}
605
+	saveHome := os.Getenv(homeEnv())
606
+	tempdir, err := ioutil.TempDir("", "home")
607
+	if err != nil {
608
+		t.Fatal("creating temp dir: ", err)
609
+	}
610
+	defer os.RemoveAll(tempdir)
611
+
612
+	// pprof writes to $HOME/pprof by default which is not necessarily
613
+	// writeable (e.g. on a Debian buildd) so set $HOME to something we
614
+	// know we can write to for the duration of the test.
615
+	os.Setenv(homeEnv(), tempdir)
616
+	defer os.Setenv(homeEnv(), saveHome)
617
+
618
+	baseVars := pprofVariables
619
+	pprofVariables = baseVars.makeCopy()
620
+	defer func() { pprofVariables = baseVars }()
621
+
622
+	cert, certBytes, keyBytes := selfSignedCert(t, "localhost")
623
+	cas := x509.NewCertPool()
624
+	cas.AppendCertsFromPEM(certBytes)
625
+
626
+	tlsConfig := &tls.Config{
627
+		RootCAs:      cas,
628
+		Certificates: []tls.Certificate{cert},
629
+		ClientAuth:   tls.RequireAndVerifyClientCert,
630
+		ClientCAs:    cas,
631
+	}
632
+
633
+	l, err := tls.Listen("tcp", "localhost:0", tlsConfig)
634
+	if err != nil {
635
+		t.Fatalf("net.Listen: got error %v, want no error", err)
636
+	}
637
+
638
+	donec := make(chan error, 1)
639
+	go func(donec chan<- error) {
640
+		donec <- http.Serve(l, nil)
641
+	}(donec)
642
+	defer func() {
643
+		if got, want := <-donec, closedError(); !strings.Contains(got.Error(), want) {
644
+			t.Fatalf("Serve got error %v, want %q", got, want)
645
+		}
646
+	}()
647
+	defer l.Close()
648
+
649
+	outputTempFile, err := ioutil.TempFile("", "profile_output")
650
+	if err != nil {
651
+		t.Fatalf("Failed to create tempfile: %v", err)
652
+	}
653
+	defer os.Remove(outputTempFile.Name())
654
+	defer outputTempFile.Close()
655
+
656
+	// Get port from the address, so request to the server can be made using
657
+	// the host name specified in certificates.
658
+	_, portStr, err := net.SplitHostPort(l.Addr().String())
659
+	if err != nil {
660
+		t.Fatalf("cannot get port from URL: %v", err)
661
+	}
662
+	address := "https://" + "localhost:" + portStr + "/debug/pprof/goroutine"
663
+	s := &source{
664
+		Sources:   []string{address},
665
+		Seconds:   10,
666
+		Timeout:   10,
667
+		Symbolize: "remote",
668
+	}
669
+
670
+	certTempFile, err := ioutil.TempFile("", "cert_output")
671
+	if err != nil {
672
+		t.Errorf("cannot create cert tempfile: %v", err)
673
+	}
674
+	defer os.Remove(certTempFile.Name())
675
+	defer certTempFile.Close()
676
+	certTempFile.Write(certBytes)
677
+
678
+	keyTempFile, err := ioutil.TempFile("", "key_output")
679
+	if err != nil {
680
+		t.Errorf("cannot create key tempfile: %v", err)
681
+	}
682
+	defer os.Remove(keyTempFile.Name())
683
+	defer keyTempFile.Close()
684
+	keyTempFile.Write(keyBytes)
685
+
686
+	f := &testFlags{
687
+		strings: map[string]string{
688
+			"tls_cert": certTempFile.Name(),
689
+			"tls_key":  keyTempFile.Name(),
690
+			"tls_ca":   certTempFile.Name(),
691
+		},
692
+	}
693
+	o := &plugin.Options{
694
+		Obj:           &binutils.Binutils{},
695
+		UI:            &proftest.TestUI{T: t, AllowRx: "Saved profile in"},
696
+		Flagset:       f,
697
+		HTTPTransport: transport.New(f),
698
+	}
699
+
700
+	o.Sym = &symbolizer.Symbolizer{Obj: o.Obj, UI: o.UI, Transport: o.HTTPTransport}
701
+	p, err := fetchProfiles(s, o)
702
+	if err != nil {
703
+		t.Fatal(err)
704
+	}
705
+	if len(p.SampleType) == 0 {
706
+		t.Fatalf("fetchProfiles(%s) got empty profile: len(p.SampleType)==0", address)
707
+	}
708
+	if len(p.Function) == 0 {
709
+		t.Fatalf("fetchProfiles(%s) got non-symbolized profile: len(p.Function)==0", address)
710
+	}
711
+	if err := checkProfileHasFunction(p, "TestHTTPSWithServerCertFetch"); err != nil {
604 712
 		t.Fatalf("fetchProfiles(%s) %v", address, err)
605 713
 	}
606 714
 }
@@ -614,7 +722,10 @@ func checkProfileHasFunction(p *profile.Profile, fname string) error {
614 722
 	return fmt.Errorf("got %s, want function %q", p.String(), fname)
615 723
 }
616 724
 
617
-func selfSignedCert(t *testing.T) tls.Certificate {
725
+// selfSignedCert generates a self-signed certificate, and returns the
726
+// generated certificate, and byte arrays containing the certificate and
727
+// key associated with the certificate.
728
+func selfSignedCert(t *testing.T, host string) (tls.Certificate, []byte, []byte) {
618 729
 	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
619 730
 	if err != nil {
620 731
 		t.Fatalf("failed to generate private key: %v", err)
@@ -629,6 +740,8 @@ func selfSignedCert(t *testing.T) tls.Certificate {
629 740
 		SerialNumber: big.NewInt(1),
630 741
 		NotBefore:    time.Now(),
631 742
 		NotAfter:     time.Now().Add(10 * time.Minute),
743
+		IsCA:         true,
744
+		DNSNames:     []string{host},
632 745
 	}
633 746
 
634 747
 	b, err = x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, privKey.Public(), privKey)
@@ -641,5 +754,5 @@ func selfSignedCert(t *testing.T) tls.Certificate {
641 754
 	if err != nil {
642 755
 		t.Fatalf("failed to create TLS key pair: %v", err)
643 756
 	}
644
-	return cert
757
+	return cert, bc, bk
645 758
 }

+ 78
- 0
internal/driver/flags.go Vedi File

@@ -0,0 +1,78 @@
1
+package driver
2
+
3
+import (
4
+	"flag"
5
+	"strings"
6
+)
7
+
8
+// GoFlags implements the plugin.FlagSet interface.
9
+type GoFlags struct {
10
+	UsageMsgs []string
11
+}
12
+
13
+// Bool implements the plugin.FlagSet interface.
14
+func (*GoFlags) Bool(o string, d bool, c string) *bool {
15
+	return flag.Bool(o, d, c)
16
+}
17
+
18
+// Int implements the plugin.FlagSet interface.
19
+func (*GoFlags) Int(o string, d int, c string) *int {
20
+	return flag.Int(o, d, c)
21
+}
22
+
23
+// Float64 implements the plugin.FlagSet interface.
24
+func (*GoFlags) Float64(o string, d float64, c string) *float64 {
25
+	return flag.Float64(o, d, c)
26
+}
27
+
28
+// String implements the plugin.FlagSet interface.
29
+func (*GoFlags) String(o, d, c string) *string {
30
+	return flag.String(o, d, c)
31
+}
32
+
33
+// BoolVar implements the plugin.FlagSet interface.
34
+func (*GoFlags) BoolVar(b *bool, o string, d bool, c string) {
35
+	flag.BoolVar(b, o, d, c)
36
+}
37
+
38
+// IntVar implements the plugin.FlagSet interface.
39
+func (*GoFlags) IntVar(i *int, o string, d int, c string) {
40
+	flag.IntVar(i, o, d, c)
41
+}
42
+
43
+// Float64Var implements the plugin.FlagSet interface.
44
+// the value of the flag.
45
+func (*GoFlags) Float64Var(f *float64, o string, d float64, c string) {
46
+	flag.Float64Var(f, o, d, c)
47
+}
48
+
49
+// StringVar implements the plugin.FlagSet interface.
50
+func (*GoFlags) StringVar(s *string, o, d, c string) {
51
+	flag.StringVar(s, o, d, c)
52
+}
53
+
54
+// StringList implements the plugin.FlagSet interface.
55
+func (*GoFlags) StringList(o, d, c string) *[]*string {
56
+	return &[]*string{flag.String(o, d, c)}
57
+}
58
+
59
+// ExtraUsage implements the plugin.FlagSet interface.
60
+func (f *GoFlags) ExtraUsage() string {
61
+	return strings.Join(f.UsageMsgs, "\n")
62
+}
63
+
64
+// AddExtraUsage implements the plugin.FlagSet interface.
65
+func (f *GoFlags) AddExtraUsage(eu string) {
66
+	f.UsageMsgs = append(f.UsageMsgs, eu)
67
+}
68
+
69
+// Parse implements the plugin.FlagSet interface.
70
+func (*GoFlags) Parse(usage func()) []string {
71
+	flag.Usage = usage
72
+	flag.Parse()
73
+	args := flag.Args()
74
+	if len(args) == 0 {
75
+		usage()
76
+	}
77
+	return args
78
+}

+ 5
- 1
internal/driver/interactive_test.go Vedi File

@@ -23,6 +23,7 @@ import (
23 23
 	"github.com/google/pprof/internal/plugin"
24 24
 	"github.com/google/pprof/internal/proftest"
25 25
 	"github.com/google/pprof/internal/report"
26
+	"github.com/google/pprof/internal/transport"
26 27
 	"github.com/google/pprof/profile"
27 28
 )
28 29
 
@@ -41,7 +42,10 @@ func TestShell(t *testing.T) {
41 42
 
42 43
 	// Random interleave of independent scripts
43 44
 	pprofVariables = testVariables(savedVariables)
44
-	o := setDefaults(nil)
45
+
46
+	// pass in HTTPTransport when setting defaults, because otherwise default
47
+	// transport will try to add flags to the default flag set.
48
+	o := setDefaults(&plugin.Options{HTTPTransport: transport.New(nil)})
45 49
 	o.UI = newUI(t, interleave(script, 0))
46 50
 	if err := interactive(p, o); err != nil {
47 51
 		t.Error("first attempt:", err)

+ 6
- 58
internal/driver/options.go Vedi File

@@ -16,7 +16,6 @@ package driver
16 16
 
17 17
 import (
18 18
 	"bufio"
19
-	"flag"
20 19
 	"fmt"
21 20
 	"io"
22 21
 	"os"
@@ -25,6 +24,7 @@ import (
25 24
 	"github.com/google/pprof/internal/binutils"
26 25
 	"github.com/google/pprof/internal/plugin"
27 26
 	"github.com/google/pprof/internal/symbolizer"
27
+	"github.com/google/pprof/internal/transport"
28 28
 )
29 29
 
30 30
 // setDefaults returns a new plugin.Options with zero fields sets to
@@ -38,7 +38,7 @@ func setDefaults(o *plugin.Options) *plugin.Options {
38 38
 		d.Writer = oswriter{}
39 39
 	}
40 40
 	if d.Flagset == nil {
41
-		d.Flagset = goFlags{}
41
+		d.Flagset = &GoFlags{}
42 42
 	}
43 43
 	if d.Obj == nil {
44 44
 		d.Obj = &binutils.Binutils{}
@@ -46,67 +46,15 @@ func setDefaults(o *plugin.Options) *plugin.Options {
46 46
 	if d.UI == nil {
47 47
 		d.UI = &stdUI{r: bufio.NewReader(os.Stdin)}
48 48
 	}
49
+	if d.HTTPTransport == nil {
50
+		d.HTTPTransport = transport.New(d.Flagset)
51
+	}
49 52
 	if d.Sym == nil {
50
-		d.Sym = &symbolizer.Symbolizer{Obj: d.Obj, UI: d.UI}
53
+		d.Sym = &symbolizer.Symbolizer{Obj: d.Obj, UI: d.UI, Transport: d.HTTPTransport}
51 54
 	}
52 55
 	return d
53 56
 }
54 57
 
55
-// goFlags returns a flagset implementation based on the standard flag
56
-// package from the Go distribution. It implements the plugin.FlagSet
57
-// interface.
58
-type goFlags struct{}
59
-
60
-func (goFlags) Bool(o string, d bool, c string) *bool {
61
-	return flag.Bool(o, d, c)
62
-}
63
-
64
-func (goFlags) Int(o string, d int, c string) *int {
65
-	return flag.Int(o, d, c)
66
-}
67
-
68
-func (goFlags) Float64(o string, d float64, c string) *float64 {
69
-	return flag.Float64(o, d, c)
70
-}
71
-
72
-func (goFlags) String(o, d, c string) *string {
73
-	return flag.String(o, d, c)
74
-}
75
-
76
-func (goFlags) BoolVar(b *bool, o string, d bool, c string) {
77
-	flag.BoolVar(b, o, d, c)
78
-}
79
-
80
-func (goFlags) IntVar(i *int, o string, d int, c string) {
81
-	flag.IntVar(i, o, d, c)
82
-}
83
-
84
-func (goFlags) Float64Var(f *float64, o string, d float64, c string) {
85
-	flag.Float64Var(f, o, d, c)
86
-}
87
-
88
-func (goFlags) StringVar(s *string, o, d, c string) {
89
-	flag.StringVar(s, o, d, c)
90
-}
91
-
92
-func (goFlags) StringList(o, d, c string) *[]*string {
93
-	return &[]*string{flag.String(o, d, c)}
94
-}
95
-
96
-func (goFlags) ExtraUsage() string {
97
-	return ""
98
-}
99
-
100
-func (goFlags) Parse(usage func()) []string {
101
-	flag.Usage = usage
102
-	flag.Parse()
103
-	args := flag.Args()
104
-	if len(args) == 0 {
105
-		usage()
106
-	}
107
-	return args
108
-}
109
-
110 58
 type stdUI struct {
111 59
 	r *bufio.Reader
112 60
 }

+ 10
- 5
internal/plugin/plugin.go Vedi File

@@ -41,7 +41,8 @@ type Options struct {
41 41
 	//
42 42
 	// A common use for a custom HTTPServer is to provide custom
43 43
 	// authentication checks.
44
-	HTTPServer func(args *HTTPServerArgs) error
44
+	HTTPServer    func(args *HTTPServerArgs) error
45
+	HTTPTransport http.RoundTripper
45 46
 }
46 47
 
47 48
 // Writer provides a mechanism to write data under a certain name,
@@ -71,12 +72,16 @@ type FlagSet interface {
71 72
 	// single flag
72 73
 	StringList(name string, def string, usage string) *[]*string
73 74
 
74
-	// ExtraUsage returns any additional text that should be
75
-	// printed after the standard usage message.
76
-	// The typical use of ExtraUsage is to show any custom flags
77
-	// defined by the specific pprof plugins being used.
75
+	// ExtraUsage returns any additional text that should be printed after the
76
+	// standard usage message. The extra usage message returned includes all text
77
+	// added with AddExtraUsage().
78
+	// The typical use of ExtraUsage is to show any custom flags defined by the
79
+	// specific pprof plugins being used.
78 80
 	ExtraUsage() string
79 81
 
82
+	// AddExtraUsage appends additional text to the end of the extra usage message.
83
+	AddExtraUsage(eu string)
84
+
80 85
 	// Parse initializes the flags with their values for this run
81 86
 	// and returns the non-flag command line arguments.
82 87
 	// If an unknown flag is encountered or there are no arguments,

+ 9
- 22
internal/symbolizer/symbolizer.go Vedi File

@@ -18,7 +18,6 @@
18 18
 package symbolizer
19 19
 
20 20
 import (
21
-	"crypto/tls"
22 21
 	"fmt"
23 22
 	"io/ioutil"
24 23
 	"net/http"
@@ -35,8 +34,9 @@ import (
35 34
 
36 35
 // Symbolizer implements the plugin.Symbolize interface.
37 36
 type Symbolizer struct {
38
-	Obj plugin.ObjTool
39
-	UI  plugin.UI
37
+	Obj       plugin.ObjTool
38
+	UI        plugin.UI
39
+	Transport http.RoundTripper
40 40
 }
41 41
 
42 42
 // test taps for dependency injection
@@ -85,7 +85,10 @@ func (s *Symbolizer) Symbolize(mode string, sources plugin.MappingSources, p *pr
85 85
 		}
86 86
 	}
87 87
 	if remote {
88
-		if err = symbolzSymbolize(p, force, sources, postURL, s.UI); err != nil {
88
+		post := func(source, post string) ([]byte, error) {
89
+			return postURL(source, post, s.Transport)
90
+		}
91
+		if err = symbolzSymbolize(p, force, sources, post, s.UI); err != nil {
89 92
 			return err // Ran out of options.
90 93
 		}
91 94
 	}
@@ -95,25 +98,9 @@ func (s *Symbolizer) Symbolize(mode string, sources plugin.MappingSources, p *pr
95 98
 }
96 99
 
97 100
 // postURL issues a POST to a URL over HTTP.
98
-func postURL(source, post string) ([]byte, error) {
99
-	url, err := url.Parse(source)
100
-	if err != nil {
101
-		return nil, err
102
-	}
103
-
104
-	var tlsConfig *tls.Config
105
-	if url.Scheme == "https+insecure" {
106
-		tlsConfig = &tls.Config{
107
-			InsecureSkipVerify: true,
108
-		}
109
-		url.Scheme = "https"
110
-		source = url.String()
111
-	}
112
-
101
+func postURL(source, post string, tr http.RoundTripper) ([]byte, error) {
113 102
 	client := &http.Client{
114
-		Transport: &http.Transport{
115
-			TLSClientConfig: tlsConfig,
116
-		},
103
+		Transport: tr,
117 104
 	}
118 105
 	resp, err := client.Post(source, "application/octet-stream", strings.NewReader(post))
119 106
 	if err != nil {

+ 2
- 2
internal/symbolizer/symbolizer_test.go Vedi File

@@ -114,8 +114,8 @@ func TestSymbolization(t *testing.T) {
114 114
 	}
115 115
 
116 116
 	s := Symbolizer{
117
-		mockObjTool{},
118
-		&proftest.TestUI{T: t},
117
+		Obj: mockObjTool{},
118
+		UI:  &proftest.TestUI{T: t},
119 119
 	}
120 120
 	for i, tc := range []testcase{
121 121
 		{

+ 131
- 0
internal/transport/transport.go Vedi File

@@ -0,0 +1,131 @@
1
+// Copyright 2018 Google Inc. All Rights Reserved.
2
+//
3
+// Licensed under the Apache License, Version 2.0 (the "License");
4
+// you may not use this file except in compliance with the License.
5
+// You may obtain a copy of the License at
6
+//
7
+//     http://www.apache.org/licenses/LICENSE-2.0
8
+//
9
+// Unless required by applicable law or agreed to in writing, software
10
+// distributed under the License is distributed on an "AS IS" BASIS,
11
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+// See the License for the specific language governing permissions and
13
+// limitations under the License.
14
+
15
+// Package transport provides a mechanism to send requests with https cert,
16
+// key, and CA.
17
+package transport
18
+
19
+import (
20
+	"crypto/tls"
21
+	"crypto/x509"
22
+	"fmt"
23
+	"io/ioutil"
24
+	"net/http"
25
+	"sync"
26
+
27
+	"github.com/google/pprof/internal/plugin"
28
+)
29
+
30
+type transport struct {
31
+	cert       *string
32
+	key        *string
33
+	ca         *string
34
+	caCertPool *x509.CertPool
35
+	certs      []tls.Certificate
36
+	initOnce   sync.Once
37
+	initErr    error
38
+}
39
+
40
+const extraUsage = `    -tls_cert             TLS client certificate file for fetching profile and symbols
41
+    -tls_key              TLS private key file for fetching profile and symbols
42
+    -tls_ca               TLS CA certs file for fetching profile and symbols`
43
+
44
+// New returns a round tripper for making requests with the
45
+// specified cert, key, and ca. The flags tls_cert, tls_key, and tls_ca are
46
+// added to the flagset to allow a user to specify the cert, key, and ca. If
47
+// the flagset is nil, no flags will be added, and users will not be able to
48
+// use these flags.
49
+func New(flagset plugin.FlagSet) http.RoundTripper {
50
+	if flagset == nil {
51
+		return &transport{}
52
+	}
53
+	flagset.AddExtraUsage(extraUsage)
54
+	return &transport{
55
+		cert: flagset.String("tls_cert", "", "TLS client certificate file for fetching profile and symbols"),
56
+		key:  flagset.String("tls_key", "", "TLS private key file for fetching profile and symbols"),
57
+		ca:   flagset.String("tls_ca", "", "TLS CA certs file for fetching profile and symbols"),
58
+	}
59
+}
60
+
61
+// initialize uses the cert, key, and ca to initialize the certs
62
+// to use these when making requests.
63
+func (tr *transport) initialize() error {
64
+	var cert, key, ca string
65
+	if tr.cert != nil {
66
+		cert = *tr.cert
67
+	}
68
+	if tr.key != nil {
69
+		key = *tr.key
70
+	}
71
+	if tr.ca != nil {
72
+		ca = *tr.ca
73
+	}
74
+
75
+	if cert != "" && key != "" {
76
+		tlsCert, err := tls.LoadX509KeyPair(cert, key)
77
+		if err != nil {
78
+			return fmt.Errorf("could not load certificate/key pair specified by -tls_cert and -tls_key: %v", err)
79
+		}
80
+		tr.certs = []tls.Certificate{tlsCert}
81
+	} else if cert == "" && key != "" {
82
+		return fmt.Errorf("-tls_key is specified, so -tls_cert must also be specified")
83
+	} else if cert != "" && key == "" {
84
+		return fmt.Errorf("-tls_cert is specified, so -tls_key must also be specified")
85
+	}
86
+
87
+	if ca != "" {
88
+		caCertPool := x509.NewCertPool()
89
+		caCert, err := ioutil.ReadFile(ca)
90
+		if err != nil {
91
+			return fmt.Errorf("could not load CA specified by -tls_ca: %v", err)
92
+		}
93
+		caCertPool.AppendCertsFromPEM(caCert)
94
+		tr.caCertPool = caCertPool
95
+	}
96
+
97
+	return nil
98
+}
99
+
100
+// RoundTrip executes a single HTTP transaction, returning
101
+// a Response for the provided Request.
102
+func (tr *transport) RoundTrip(req *http.Request) (*http.Response, error) {
103
+	tr.initOnce.Do(func() {
104
+		tr.initErr = tr.initialize()
105
+	})
106
+	if tr.initErr != nil {
107
+		return nil, tr.initErr
108
+	}
109
+
110
+	tlsConfig := &tls.Config{
111
+		RootCAs:      tr.caCertPool,
112
+		Certificates: tr.certs,
113
+	}
114
+
115
+	if req.URL.Scheme == "https+insecure" {
116
+		// Make shallow copy of request, and req.URL, so the request's URL can be
117
+		// modified.
118
+		r := *req
119
+		*r.URL = *req.URL
120
+		req = &r
121
+		tlsConfig.InsecureSkipVerify = true
122
+		req.URL.Scheme = "https"
123
+	}
124
+
125
+	transport := http.Transport{
126
+		Proxy:           http.ProxyFromEnvironment,
127
+		TLSClientConfig: tlsConfig,
128
+	}
129
+
130
+	return transport.RoundTrip(req)
131
+}