Passed
Push — master ( f028f1...4fb92b )
by Stefano
02:12
created

cmd_test.startSocks5TestServer   A

Complexity

Conditions 5

Size

Total Lines 20
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 12
nop 1
dl 0
loc 20
rs 9.3333
c 0
b 0
f 0
1
package cmd_test
2
3
import (
4
	"io/ioutil"
5
	"net"
6
	"net/http"
7
	"net/http/httptest"
8
	"os"
9
	"sync"
10
	"syscall"
11
	"testing"
12
	"time"
13
14
	"github.com/armon/go-socks5"
15
	"github.com/stefanoj3/dirstalk/pkg/common/test"
16
	"github.com/stretchr/testify/assert"
17
)
18
19
// socks5TestServerHost is the fixed address the SOCKS5 test proxy started by
// startSocks5TestServer listens on, and the value passed to --socks5 by the
// test that expects the proxy to be reachable.
// NOTE(review): the port is hard-coded, so tests that start the proxy
// presumably cannot run in parallel with each other - confirm.
const socks5TestServerHost = "127.0.0.1:8899"
20
21
func TestScanCommand(t *testing.T) {
22
	logger, loggerBuffer := test.NewLogger()
23
24
	c, err := createCommand(logger)
25
	assert.NoError(t, err)
26
	assert.NotNil(t, c)
27
28
	testServer, serverAssertion := test.NewServerWithAssertion(
29
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30
			if r.URL.Path == "/test/" {
31
				w.WriteHeader(http.StatusOK)
32
				return
33
			}
34
			if r.URL.Path == "/potato" {
35
				w.WriteHeader(http.StatusOK)
36
				return
37
			}
38
39
			if r.URL.Path == "/test/test/" {
40
				http.Redirect(w, r, "/potato", http.StatusMovedPermanently)
41
				return
42
			}
43
44
			w.WriteHeader(http.StatusNotFound)
45
		}),
46
	)
47
	defer testServer.Close()
48
49
	_, _, err = executeCommand(
50
		c,
51
		"scan",
52
		testServer.URL,
53
		"--dictionary",
54
		"testdata/dict2.txt",
55
		"-v",
56
		"--http-statuses-to-ignore",
57
		"404",
58
		"--http-timeout",
59
		"300",
60
	)
61
	assert.NoError(t, err)
62
63
	assert.Equal(t, 17, serverAssertion.Len())
64
65
	requestsMap := map[string]string{}
66
67
	serverAssertion.Range(func(_ int, r http.Request) {
68
		requestsMap[r.URL.Path] = r.Method
69
	})
70
71
	expectedRequests := map[string]string{
72
		"/test/":               http.MethodGet,
73
		"/test/home":           http.MethodGet,
74
		"/test/blabla":         http.MethodGet,
75
		"/test/home/index.php": http.MethodGet,
76
		"/potato":              http.MethodGet,
77
78
		"/potato/test/":          http.MethodGet,
79
		"/potato/home":           http.MethodGet,
80
		"/potato/home/index.php": http.MethodGet,
81
		"/potato/blabla":         http.MethodGet,
82
83
		"/test/test/test/":          http.MethodGet,
84
		"/test/test/home":           http.MethodGet,
85
		"/test/test/home/index.php": http.MethodGet,
86
		"/test/test/blabla":         http.MethodGet,
87
88
		"/test/test/": http.MethodGet,
89
90
		"/home":           http.MethodGet,
91
		"/blabla":         http.MethodGet,
92
		"/home/index.php": http.MethodGet,
93
	}
94
95
	assert.Equal(t, expectedRequests, requestsMap)
96
97
	expectedResultTree := `/
98
├── potato
99
└── test
100
    └── test
101
102
`
103
104
	assert.Contains(t, loggerBuffer.String(), expectedResultTree)
105
}
106
107
func TestScanShouldWriteOutput(t *testing.T) {
108
	logger, _ := test.NewLogger()
109
110
	c, err := createCommand(logger)
111
	assert.NoError(t, err)
112
	assert.NotNil(t, c)
113
114
	testServer, _ := test.NewServerWithAssertion(
115
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
116
			if r.URL.Path == "/home" {
117
				w.WriteHeader(http.StatusOK)
118
				return
119
			}
120
121
			w.WriteHeader(http.StatusNotFound)
122
		}),
123
	)
124
	defer testServer.Close()
125
126
	outputFilename := test.RandStringRunes(10)
127
	outputFilename = "testdata/out/" + outputFilename + ".txt"
128
129
	defer func() {
130
		err := os.Remove(outputFilename)
131
		if err != nil {
132
			panic("failed to remove file create during test: " + err.Error())
133
		}
134
	}()
135
136
	_, _, err = executeCommand(
137
		c,
138
		"scan",
139
		testServer.URL,
140
		"--dictionary",
141
		"testdata/dict2.txt",
142
		"--out",
143
		outputFilename,
144
	)
145
	assert.NoError(t, err)
146
147
	file, err := os.Open(outputFilename)
148
	assert.NoError(t, err)
149
150
	b, err := ioutil.ReadAll(file)
151
	assert.NoError(t, err, "failed to read file content")
152
153
	expected := `{"Target":{"Path":"home","Method":"GET","Depth":3},"StatusCode":200,"URL":{"Scheme":"http","Opaque":"","User":null,"Host":"` +
154
		testServer.Listener.Addr().String() +
155
		`","Path":"/home","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":""}}
156
`
157
	assert.Equal(t, expected, string(b))
158
159
	assert.NoError(t, file.Close(), "failed to close file")
160
}
161
162
func TestScanInvalidOutputFileShouldErr(t *testing.T) {
163
	logger, _ := test.NewLogger()
164
165
	c, err := createCommand(logger)
166
	assert.NoError(t, err)
167
	assert.NotNil(t, c)
168
169
	testServer, _ := test.NewServerWithAssertion(
170
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171
			if r.URL.Path == "/home" {
172
				w.WriteHeader(http.StatusOK)
173
				return
174
			}
175
176
			w.WriteHeader(http.StatusNotFound)
177
		}),
178
	)
179
	defer testServer.Close()
180
181
	_, _, err = executeCommand(
182
		c,
183
		"scan",
184
		testServer.URL,
185
		"--dictionary",
186
		"testdata/dict2.txt",
187
		"--out",
188
		"/root/blabla/123/gibberish/123",
189
	)
190
	assert.Error(t, err)
191
	assert.Contains(t, err.Error(), "failed to create output saver")
192
}
193
194
func TestScanWithInvalidStatusesToIgnoreShouldErr(t *testing.T) {
195
	logger, _ := test.NewLogger()
196
197
	c, err := createCommand(logger)
198
	assert.NoError(t, err)
199
	assert.NotNil(t, c)
200
201
	testServer, serverAssertion := test.NewServerWithAssertion(
202
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
203
	)
204
	defer testServer.Close()
205
206
	_, _, err = executeCommand(
207
		c,
208
		"scan",
209
		testServer.URL,
210
		"--dictionary",
211
		"testdata/dict2.txt",
212
		"-v",
213
		"--http-statuses-to-ignore",
214
		"300,gibberish,404",
215
		"--http-timeout",
216
		"300",
217
	)
218
	assert.Error(t, err)
219
	assert.Contains(t, err.Error(), "strconv.Atoi: parsing")
220
	assert.Contains(t, err.Error(), "gibberish")
221
222
	assert.Equal(t, 0, serverAssertion.Len())
223
}
224
225
func TestScanWithNoTargetShouldErr(t *testing.T) {
226
	logger, _ := test.NewLogger()
227
228
	c, err := createCommand(logger)
229
	assert.NoError(t, err)
230
	assert.NotNil(t, c)
231
232
	_, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt")
233
	assert.Error(t, err)
234
	assert.Contains(t, err.Error(), "no URL provided")
235
}
236
237
func TestScanWithInvalidTargetShouldErr(t *testing.T) {
238
	logger, _ := test.NewLogger()
239
240
	c, err := createCommand(logger)
241
	assert.NoError(t, err)
242
	assert.NotNil(t, c)
243
244
	_, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt", "localhost%%2")
245
	assert.Error(t, err)
246
	assert.Contains(t, err.Error(), "invalid URI")
247
}
248
249
func TestScanCommandCanBeInterrupted(t *testing.T) {
250
	logger, loggerBuffer := test.NewLogger()
251
252
	c, err := createCommand(logger)
253
	assert.NoError(t, err)
254
	assert.NotNil(t, c)
255
256
	testServer, serverAssertion := test.NewServerWithAssertion(
257
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
258
			time.Sleep(time.Millisecond * 650)
259
260
			if r.URL.Path == "/test/" {
261
				w.WriteHeader(http.StatusOK)
262
				return
263
			}
264
265
			w.WriteHeader(http.StatusNotFound)
266
		}),
267
	)
268
	defer testServer.Close()
269
270
	go func() {
271
		time.Sleep(time.Millisecond * 200)
272
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
273
	}()
274
275
	_, _, err = executeCommand(
276
		c,
277
		"scan",
278
		testServer.URL,
279
		"--dictionary",
280
		"testdata/dict2.txt",
281
		"-v",
282
		"--http-timeout",
283
		"900",
284
	)
285
	assert.NoError(t, err)
286
287
	assert.True(t, serverAssertion.Len() > 0)
288
	assert.Contains(t, loggerBuffer.String(), "Received sigint")
289
}
290
291
func TestScanWithRemoteDictionary(t *testing.T) {
292
	logger, _ := test.NewLogger()
293
294
	c, err := createCommand(logger)
295
	assert.NoError(t, err)
296
	assert.NotNil(t, c)
297
298
	dictionaryServer := httptest.NewServer(
299
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
300
			dict := `home
301
home/index.php
302
blabla
303
`
304
			w.WriteHeader(http.StatusOK)
305
			_, _ = w.Write([]byte(dict))
306
		}),
307
	)
308
	defer dictionaryServer.Close()
309
310
	testServer, serverAssertion := test.NewServerWithAssertion(
311
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
312
			w.WriteHeader(http.StatusNotFound)
313
		}),
314
	)
315
	defer testServer.Close()
316
317
	_, _, err = executeCommand(
318
		c,
319
		"scan",
320
		testServer.URL,
321
		"--dictionary",
322
		dictionaryServer.URL,
323
		"--http-timeout",
324
		"300",
325
	)
326
	assert.NoError(t, err)
327
328
	assert.Equal(t, 3, serverAssertion.Len())
329
}
330
331
func TestScanWithUserAgentFlag(t *testing.T) {
332
	const testUserAgent = "my_test_user_agent"
333
334
	logger, loggerBuffer := test.NewLogger()
335
336
	c, err := createCommand(logger)
337
	assert.NoError(t, err)
338
	assert.NotNil(t, c)
339
340
	testServer, serverAssertion := test.NewServerWithAssertion(
341
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
342
			w.WriteHeader(http.StatusNotFound)
343
		}),
344
	)
345
	defer testServer.Close()
346
347
	_, _, err = executeCommand(
348
		c,
349
		"scan",
350
		testServer.URL,
351
		"--user-agent",
352
		testUserAgent,
353
		"--dictionary",
354
		"testdata/dict.txt",
355
		"--http-timeout",
356
		"300",
357
	)
358
	assert.NoError(t, err)
359
360
	assert.Equal(t, 3, serverAssertion.Len())
361
	serverAssertion.Range(func(_ int, r http.Request) {
362
		assert.Equal(t, testUserAgent, r.Header.Get("User-Agent"))
363
	})
364
365
	// to ensure we print the user agent to the cli
366
	assert.Contains(t, loggerBuffer.String(), testUserAgent)
367
}
368
369
func TestScanWithCookies(t *testing.T) {
370
	logger, loggerBuffer := test.NewLogger()
371
372
	c, err := createCommand(logger)
373
	assert.NoError(t, err)
374
	assert.NotNil(t, c)
375
376
	testServer, serverAssertion := test.NewServerWithAssertion(
377
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
378
	)
379
	defer testServer.Close()
380
381
	_, _, err = executeCommand(
382
		c,
383
		"scan",
384
		testServer.URL,
385
		"--cookie",
386
		"name1=val1",
387
		"--cookie",
388
		"name2=val2",
389
		"--dictionary",
390
		"testdata/dict.txt",
391
		"--http-timeout",
392
		"300",
393
	)
394
	assert.NoError(t, err)
395
396
	serverAssertion.Range(func(_ int, r http.Request) {
397
		assert.Equal(t, 2, len(r.Cookies()))
398
399
		assert.Equal(t, r.Cookies()[0].Name, "name1")
400
		assert.Equal(t, r.Cookies()[0].Value, "val1")
401
402
		assert.Equal(t, r.Cookies()[1].Name, "name2")
403
		assert.Equal(t, r.Cookies()[1].Value, "val2")
404
	})
405
406
	// to ensure we print the cookies to the cli
407
	assert.Contains(t, loggerBuffer.String(), "name1=val1")
408
	assert.Contains(t, loggerBuffer.String(), "name2=val2")
409
}
410
411
func TestWhenProvidingCookiesInWrongFormatShouldErr(t *testing.T) {
412
	const malformedCookie = "gibberish"
413
414
	logger, _ := test.NewLogger()
415
416
	c, err := createCommand(logger)
417
	assert.NoError(t, err)
418
	assert.NotNil(t, c)
419
420
	testServer, serverAssertion := test.NewServerWithAssertion(
421
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
422
			w.WriteHeader(http.StatusNotFound)
423
		}),
424
	)
425
	defer testServer.Close()
426
427
	_, _, err = executeCommand(
428
		c,
429
		"scan",
430
		testServer.URL,
431
		"--cookie",
432
		malformedCookie,
433
		"--dictionary",
434
		"testdata/dict.txt",
435
	)
436
	assert.Error(t, err)
437
	assert.Contains(t, err.Error(), "cookie format is invalid")
438
	assert.Contains(t, err.Error(), malformedCookie)
439
440
	assert.Equal(t, 0, serverAssertion.Len())
441
}
442
443
func TestScanWithCookieJar(t *testing.T) {
444
	const (
445
		serverCookieName  = "server_cookie_name"
446
		serverCookieValue = "server_cookie_value"
447
	)
448
449
	logger, _ := test.NewLogger()
450
451
	c, err := createCommand(logger)
452
	assert.NoError(t, err)
453
	assert.NotNil(t, c)
454
455
	once := sync.Once{}
456
	testServer, serverAssertion := test.NewServerWithAssertion(
457
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
458
			once.Do(func() {
459
				http.SetCookie(
460
					w,
461
					&http.Cookie{
462
						Name:    serverCookieName,
463
						Value:   serverCookieValue,
464
						Expires: time.Now().AddDate(0, 1, 0),
465
					},
466
				)
467
			})
468
		}),
469
	)
470
	defer testServer.Close()
471
472
	_, _, err = executeCommand(
473
		c,
474
		"scan",
475
		testServer.URL,
476
		"--use-cookie-jar",
477
		"--dictionary",
478
		"testdata/dict.txt",
479
		"--http-timeout",
480
		"300",
481
		"-t",
482
		"1",
483
	)
484
	assert.NoError(t, err)
485
486
	serverAssertion.Range(func(index int, r http.Request) {
487
		if index == 0 { // first request should have no cookies
488
			assert.Equal(t, 0, len(r.Cookies()))
489
			return
490
		}
491
492
		assert.Equal(t, 1, len(r.Cookies()))
493
		assert.Equal(t, r.Cookies()[0].Name, serverCookieName)
494
		assert.Equal(t, r.Cookies()[0].Value, serverCookieValue)
495
	})
496
}
497
498
func TestScanWithUnknownFlagShouldErr(t *testing.T) {
499
	logger, _ := test.NewLogger()
500
501
	c, err := createCommand(logger)
502
	assert.NoError(t, err)
503
	assert.NotNil(t, c)
504
505
	testServer, serverAssertion := test.NewServerWithAssertion(
506
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
507
	)
508
	defer testServer.Close()
509
510
	_, _, err = executeCommand(
511
		c,
512
		"scan",
513
		testServer.URL,
514
		"--gibberishflag",
515
		"--dictionary",
516
		"testdata/dict.txt",
517
	)
518
	assert.Error(t, err)
519
	assert.Contains(t, err.Error(), "unknown flag")
520
521
	assert.Equal(t, 0, serverAssertion.Len())
522
}
523
524
func TestScanWithHeaders(t *testing.T) {
525
	logger, loggerBuffer := test.NewLogger()
526
527
	c, err := createCommand(logger)
528
	assert.NoError(t, err)
529
	assert.NotNil(t, c)
530
531
	testServer, serverAssertion := test.NewServerWithAssertion(
532
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
533
	)
534
	defer testServer.Close()
535
536
	_, _, err = executeCommand(
537
		c,
538
		"scan",
539
		testServer.URL,
540
		"--header",
541
		"Accept-Language: en-US,en;q=0.5",
542
		"--header",
543
		`"Authorization: Bearer 123"`,
544
		"--dictionary",
545
		"testdata/dict.txt",
546
		"--http-timeout",
547
		"300",
548
	)
549
	assert.NoError(t, err)
550
551
	serverAssertion.Range(func(_ int, r http.Request) {
552
		assert.Equal(t, 2, len(r.Header))
553
554
		assert.Equal(t, "en-US,en;q=0.5", r.Header.Get("Accept-Language"))
555
		assert.Equal(t, "Bearer 123", r.Header.Get("Authorization"))
556
	})
557
558
	// to ensure we print the headers to the cli
559
	assert.Contains(t, loggerBuffer.String(), "Accept-Language")
560
	assert.Contains(t, loggerBuffer.String(), "Authorization")
561
	assert.Contains(t, loggerBuffer.String(), "Bearer 123")
562
}
563
564
func TestScanWithMalformedHeaderShouldErr(t *testing.T) {
565
	const malformedHeader = "gibberish"
566
567
	logger, _ := test.NewLogger()
568
569
	c, err := createCommand(logger)
570
	assert.NoError(t, err)
571
	assert.NotNil(t, c)
572
573
	testServer, serverAssertion := test.NewServerWithAssertion(
574
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
575
	)
576
	defer testServer.Close()
577
578
	_, _, err = executeCommand(
579
		c,
580
		"scan",
581
		testServer.URL,
582
		"--header",
583
		"Accept-Language: en-US,en;q=0.5",
584
		"--header",
585
		malformedHeader,
586
		"--dictionary",
587
		"testdata/dict.txt",
588
	)
589
	assert.Error(t, err)
590
	assert.Contains(t, err.Error(), malformedHeader)
591
	assert.Contains(t, err.Error(), "header is in invalid format")
592
593
	assert.Equal(t, 0, serverAssertion.Len())
594
}
595
596
func TestStartScanWithSocks5ShouldFindResultsWhenAServerIsAvailable(t *testing.T) {
597
	logger, _ := test.NewLogger()
598
599
	c, err := createCommand(logger)
600
	assert.NoError(t, err)
601
	assert.NotNil(t, c)
602
603
	testServer, serverAssertion := test.NewServerWithAssertion(
604
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
605
			w.WriteHeader(http.StatusNotFound)
606
		}),
607
	)
608
	defer testServer.Close()
609
610
	socks5Server := startSocks5TestServer(t)
611
	defer socks5Server.Close()
612
613
	_, _, err = executeCommand(
614
		c,
615
		"scan",
616
		testServer.URL,
617
		"--dictionary",
618
		"testdata/dict.txt",
619
		"-v",
620
		"--http-timeout",
621
		"300",
622
		"--socks5",
623
		socks5TestServerHost,
624
	)
625
	assert.NoError(t, err)
626
627
	assert.Equal(t, 3, serverAssertion.Len())
628
}
629
630
func TestShouldFailToScanWithAnUnreachableSocks5Server(t *testing.T) {
631
	logger, loggerBuffer := test.NewLogger()
632
633
	c, err := createCommand(logger)
634
	assert.NoError(t, err)
635
	assert.NotNil(t, c)
636
637
	testServer, serverAssertion := test.NewServerWithAssertion(
638
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
639
			w.WriteHeader(http.StatusNotFound)
640
		}),
641
	)
642
	defer testServer.Close()
643
644
	socks5Server := startSocks5TestServer(t)
645
	defer socks5Server.Close()
646
647
	_, _, err = executeCommand(
648
		c,
649
		"scan",
650
		testServer.URL,
651
		"--dictionary",
652
		"testdata/dict.txt",
653
		"-v",
654
		"--http-timeout",
655
		"300",
656
		"--socks5",
657
		"127.0.0.1:9555", // invalid
658
	)
659
	assert.NoError(t, err)
660
661
	assert.Equal(t, 0, serverAssertion.Len())
662
	assert.Contains(t, loggerBuffer.String(), "failed to perform request")
663
	assert.Contains(t, loggerBuffer.String(), "socks connect tcp")
664
	assert.Contains(t, loggerBuffer.String(), "connect: connection refused")
665
}
666
667
func TestShouldFailToStartWithAnInvalidSocks5Address(t *testing.T) {
668
	logger, _ := test.NewLogger()
669
670
	c, err := createCommand(logger)
671
	assert.NoError(t, err)
672
	assert.NotNil(t, c)
673
674
	testServer, serverAssertion := test.NewServerWithAssertion(
675
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
676
			w.WriteHeader(http.StatusNotFound)
677
		}),
678
	)
679
	defer testServer.Close()
680
681
	_, _, err = executeCommand(
682
		c,
683
		"scan",
684
		testServer.URL,
685
		"--dictionary",
686
		"testdata/dict.txt",
687
		"-v",
688
		"--http-timeout",
689
		"300",
690
		"--socks5",
691
		"localhost%%2", // invalid
692
	)
693
	assert.Error(t, err)
694
	assert.Contains(t, err.Error(), "invalid URL escape")
695
696
	assert.Equal(t, 0, serverAssertion.Len())
697
}
698
699
func startSocks5TestServer(t *testing.T) net.Listener {
700
	conf := &socks5.Config{}
701
	server, err := socks5.New(conf)
702
	if err != nil {
703
		t.Fatalf("failed to create socks5: %s", err.Error())
704
	}
705
706
	listener, err := net.Listen("tcp", socks5TestServerHost)
707
	if err != nil {
708
		t.Fatalf("failed to create listener: %s", err.Error())
709
	}
710
711
	go func() {
712
		// Create SOCKS5 proxy on localhost port 8000
713
		if err := server.Serve(listener); err != nil {
714
			t.Logf("socks5 stopped serving: %s", err.Error())
715
		}
716
	}()
717
718
	return listener
719
}
720