Passed
Push — master ( fc31a0...e87dc6 )
by Stefano
02:12
created

cmd_test.TestScanWithInvalidTargetShouldErr   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 8
nop 1
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
package cmd_test
2
3
import (
4
	"net"
5
	"net/http"
6
	"net/http/httptest"
7
	"sync"
8
	"syscall"
9
	"testing"
10
	"time"
11
12
	"github.com/armon/go-socks5"
13
	"github.com/stefanoj3/dirstalk/pkg/common/test"
14
	"github.com/stretchr/testify/assert"
15
)
16
17
// socks5TestServerHost is the fixed loopback address the in-process SOCKS5
// test proxy listens on (see startSocks5TestServer) and that the socks5 tests
// pass to the "--socks5" flag.
const socks5TestServerHost = "127.0.0.1:8899"
18
19
func TestScanCommand(t *testing.T) {
20
	logger, _ := test.NewLogger()
21
22
	c, err := createCommand(logger)
23
	assert.NoError(t, err)
24
	assert.NotNil(t, c)
25
26
	testServer, serverAssertion := test.NewServerWithAssertion(
27
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28
			if r.URL.Path == "/test/" {
29
				w.WriteHeader(http.StatusOK)
30
				return
31
			}
32
33
			w.WriteHeader(http.StatusNotFound)
34
		}),
35
	)
36
	defer testServer.Close()
37
38
	_, _, err = executeCommand(
39
		c,
40
		"scan",
41
		testServer.URL,
42
		"--dictionary",
43
		"testdata/dict2.txt",
44
		"-v",
45
		"--http-timeout",
46
		"300",
47
	)
48
	assert.NoError(t, err)
49
50
	assert.Equal(t, 8, serverAssertion.Len())
51
52
	requestsMap := map[string]string{}
53
54
	serverAssertion.Range(func(_ int, r http.Request) {
55
		requestsMap[r.URL.Path] = r.Method
56
	})
57
58
	expectedRequests := map[string]string{
59
		"/test/":               http.MethodGet,
60
		"/test/home":           http.MethodGet,
61
		"/test/blabla":         http.MethodGet,
62
		"/test/home/index.php": http.MethodGet,
63
		"/test/test/":          http.MethodGet,
64
65
		"/home":           http.MethodGet,
66
		"/blabla":         http.MethodGet,
67
		"/home/index.php": http.MethodGet,
68
	}
69
70
	assert.Equal(t, expectedRequests, requestsMap)
71
}
72
73
func TestScanWithNoTargetShouldErr(t *testing.T) {
74
	logger, _ := test.NewLogger()
75
76
	c, err := createCommand(logger)
77
	assert.NoError(t, err)
78
	assert.NotNil(t, c)
79
80
	_, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt")
81
	assert.Error(t, err)
82
	assert.Contains(t, err.Error(), "no URL provided")
83
}
84
85
func TestScanWithInvalidTargetShouldErr(t *testing.T) {
86
	logger, _ := test.NewLogger()
87
88
	c, err := createCommand(logger)
89
	assert.NoError(t, err)
90
	assert.NotNil(t, c)
91
92
	_, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt", "localhost%%2")
93
	assert.Error(t, err)
94
	assert.Contains(t, err.Error(), "invalid URI")
95
}
96
97
func TestScanCommandCanBeInterrupted(t *testing.T) {
98
	logger, loggerBuffer := test.NewLogger()
99
100
	c, err := createCommand(logger)
101
	assert.NoError(t, err)
102
	assert.NotNil(t, c)
103
104
	testServer, serverAssertion := test.NewServerWithAssertion(
105
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
106
			time.Sleep(time.Millisecond * 650)
107
108
			if r.URL.Path == "/test/" {
109
				w.WriteHeader(http.StatusOK)
110
				return
111
			}
112
113
			w.WriteHeader(http.StatusNotFound)
114
		}),
115
	)
116
	defer testServer.Close()
117
118
	go func() {
119
		time.Sleep(time.Millisecond * 200)
120
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
121
	}()
122
123
	_, _, err = executeCommand(
124
		c,
125
		"scan",
126
		testServer.URL,
127
		"--dictionary",
128
		"testdata/dict2.txt",
129
		"-v",
130
		"--http-timeout",
131
		"900",
132
	)
133
	assert.NoError(t, err)
134
135
	assert.True(t, serverAssertion.Len() > 0)
136
	assert.Contains(t, loggerBuffer.String(), "Received sigint")
137
}
138
139
func TestScanWithRemoteDictionary(t *testing.T) {
140
	logger, _ := test.NewLogger()
141
142
	c, err := createCommand(logger)
143
	assert.NoError(t, err)
144
	assert.NotNil(t, c)
145
146
	dictionaryServer := httptest.NewServer(
147
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
148
			dict := `home
149
home/index.php
150
blabla
151
`
152
			w.WriteHeader(http.StatusOK)
153
			_, _ = w.Write([]byte(dict))
154
		}),
155
	)
156
	defer dictionaryServer.Close()
157
158
	testServer, serverAssertion := test.NewServerWithAssertion(
159
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160
			w.WriteHeader(http.StatusNotFound)
161
		}),
162
	)
163
	defer testServer.Close()
164
165
	_, _, err = executeCommand(
166
		c,
167
		"scan",
168
		testServer.URL,
169
		"--dictionary",
170
		dictionaryServer.URL,
171
		"--http-timeout",
172
		"300",
173
	)
174
	assert.NoError(t, err)
175
176
	assert.Equal(t, 3, serverAssertion.Len())
177
}
178
179
func TestScanWithUserAgentFlag(t *testing.T) {
180
	const testUserAgent = "my_test_user_agent"
181
182
	logger, loggerBuffer := test.NewLogger()
183
184
	c, err := createCommand(logger)
185
	assert.NoError(t, err)
186
	assert.NotNil(t, c)
187
188
	testServer, serverAssertion := test.NewServerWithAssertion(
189
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
190
			w.WriteHeader(http.StatusNotFound)
191
		}),
192
	)
193
	defer testServer.Close()
194
195
	_, _, err = executeCommand(
196
		c,
197
		"scan",
198
		testServer.URL,
199
		"--user-agent",
200
		testUserAgent,
201
		"--dictionary",
202
		"testdata/dict.txt",
203
		"--http-timeout",
204
		"300",
205
	)
206
	assert.NoError(t, err)
207
208
	assert.Equal(t, 3, serverAssertion.Len())
209
	serverAssertion.Range(func(_ int, r http.Request) {
210
		assert.Equal(t, testUserAgent, r.Header.Get("User-Agent"))
211
	})
212
213
	// to ensure we print the user agent to the cli
214
	assert.Contains(t, loggerBuffer.String(), testUserAgent)
215
}
216
217
func TestScanWithCookies(t *testing.T) {
218
	logger, loggerBuffer := test.NewLogger()
219
220
	c, err := createCommand(logger)
221
	assert.NoError(t, err)
222
	assert.NotNil(t, c)
223
224
	testServer, serverAssertion := test.NewServerWithAssertion(
225
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
226
	)
227
	defer testServer.Close()
228
229
	_, _, err = executeCommand(
230
		c,
231
		"scan",
232
		testServer.URL,
233
		"--cookie",
234
		"name1=val1",
235
		"--cookie",
236
		"name2=val2",
237
		"--dictionary",
238
		"testdata/dict.txt",
239
		"--http-timeout",
240
		"300",
241
	)
242
	assert.NoError(t, err)
243
244
	serverAssertion.Range(func(_ int, r http.Request) {
245
		assert.Equal(t, 2, len(r.Cookies()))
246
247
		assert.Equal(t, r.Cookies()[0].Name, "name1")
248
		assert.Equal(t, r.Cookies()[0].Value, "val1")
249
250
		assert.Equal(t, r.Cookies()[1].Name, "name2")
251
		assert.Equal(t, r.Cookies()[1].Value, "val2")
252
	})
253
254
	// to ensure we print the cookies to the cli
255
	assert.Contains(t, loggerBuffer.String(), "name1=val1")
256
	assert.Contains(t, loggerBuffer.String(), "name2=val2")
257
}
258
259
func TestWhenProvidingCookiesInWrongFormatShouldErr(t *testing.T) {
260
	const malformedCookie = "gibberish"
261
262
	logger, _ := test.NewLogger()
263
264
	c, err := createCommand(logger)
265
	assert.NoError(t, err)
266
	assert.NotNil(t, c)
267
268
	testServer, serverAssertion := test.NewServerWithAssertion(
269
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
270
			w.WriteHeader(http.StatusNotFound)
271
		}),
272
	)
273
	defer testServer.Close()
274
275
	_, _, err = executeCommand(
276
		c,
277
		"scan",
278
		testServer.URL,
279
		"--cookie",
280
		malformedCookie,
281
		"--dictionary",
282
		"testdata/dict.txt",
283
	)
284
	assert.Error(t, err)
285
	assert.Contains(t, err.Error(), "cookie format is invalid")
286
	assert.Contains(t, err.Error(), malformedCookie)
287
288
	assert.Equal(t, 0, serverAssertion.Len())
289
}
290
291
func TestScanWithCookieJar(t *testing.T) {
292
	const (
293
		serverCookieName  = "server_cookie_name"
294
		serverCookieValue = "server_cookie_value"
295
	)
296
297
	logger, _ := test.NewLogger()
298
299
	c, err := createCommand(logger)
300
	assert.NoError(t, err)
301
	assert.NotNil(t, c)
302
303
	once := sync.Once{}
304
	testServer, serverAssertion := test.NewServerWithAssertion(
305
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
306
			once.Do(func() {
307
				http.SetCookie(
308
					w,
309
					&http.Cookie{
310
						Name:    serverCookieName,
311
						Value:   serverCookieValue,
312
						Expires: time.Now().AddDate(0, 1, 0),
313
					},
314
				)
315
			})
316
		}),
317
	)
318
	defer testServer.Close()
319
320
	_, _, err = executeCommand(
321
		c,
322
		"scan",
323
		testServer.URL,
324
		"--use-cookie-jar",
325
		"--dictionary",
326
		"testdata/dict.txt",
327
		"--http-timeout",
328
		"300",
329
		"-t",
330
		"1",
331
	)
332
	assert.NoError(t, err)
333
334
	serverAssertion.Range(func(index int, r http.Request) {
335
		if index == 0 { // first request should have no cookies
336
			assert.Equal(t, 0, len(r.Cookies()))
337
			return
338
		}
339
340
		assert.Equal(t, 1, len(r.Cookies()))
341
		assert.Equal(t, r.Cookies()[0].Name, serverCookieName)
342
		assert.Equal(t, r.Cookies()[0].Value, serverCookieValue)
343
	})
344
}
345
346
func TestScanWithUnknownFlagShouldErr(t *testing.T) {
347
	logger, _ := test.NewLogger()
348
349
	c, err := createCommand(logger)
350
	assert.NoError(t, err)
351
	assert.NotNil(t, c)
352
353
	testServer, serverAssertion := test.NewServerWithAssertion(
354
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
355
	)
356
	defer testServer.Close()
357
358
	_, _, err = executeCommand(
359
		c,
360
		"scan",
361
		testServer.URL,
362
		"--gibberishflag",
363
		"--dictionary",
364
		"testdata/dict.txt",
365
	)
366
	assert.Error(t, err)
367
	assert.Contains(t, err.Error(), "unknown flag")
368
369
	assert.Equal(t, 0, serverAssertion.Len())
370
}
371
372
func TestScanWithHeaders(t *testing.T) {
373
	logger, loggerBuffer := test.NewLogger()
374
375
	c, err := createCommand(logger)
376
	assert.NoError(t, err)
377
	assert.NotNil(t, c)
378
379
	testServer, serverAssertion := test.NewServerWithAssertion(
380
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
381
	)
382
	defer testServer.Close()
383
384
	_, _, err = executeCommand(
385
		c,
386
		"scan",
387
		testServer.URL,
388
		"--header",
389
		"Accept-Language: en-US,en;q=0.5",
390
		"--header",
391
		`"Authorization: Bearer 123"`,
392
		"--dictionary",
393
		"testdata/dict.txt",
394
		"--http-timeout",
395
		"300",
396
	)
397
	assert.NoError(t, err)
398
399
	serverAssertion.Range(func(_ int, r http.Request) {
400
		assert.Equal(t, 2, len(r.Header))
401
402
		assert.Equal(t, "en-US,en;q=0.5", r.Header.Get("Accept-Language"))
403
		assert.Equal(t, "Bearer 123", r.Header.Get("Authorization"))
404
	})
405
406
	// to ensure we print the headers to the cli
407
	assert.Contains(t, loggerBuffer.String(), "Accept-Language")
408
	assert.Contains(t, loggerBuffer.String(), "Authorization")
409
	assert.Contains(t, loggerBuffer.String(), "Bearer 123")
410
}
411
412
func TestScanWithMalformedHeaderShouldErr(t *testing.T) {
413
	const malformedHeader = "gibberish"
414
415
	logger, _ := test.NewLogger()
416
417
	c, err := createCommand(logger)
418
	assert.NoError(t, err)
419
	assert.NotNil(t, c)
420
421
	testServer, serverAssertion := test.NewServerWithAssertion(
422
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
423
	)
424
	defer testServer.Close()
425
426
	_, _, err = executeCommand(
427
		c,
428
		"scan",
429
		testServer.URL,
430
		"--header",
431
		"Accept-Language: en-US,en;q=0.5",
432
		"--header",
433
		malformedHeader,
434
		"--dictionary",
435
		"testdata/dict.txt",
436
	)
437
	assert.Error(t, err)
438
	assert.Contains(t, err.Error(), malformedHeader)
439
	assert.Contains(t, err.Error(), "header is in invalid format")
440
441
	assert.Equal(t, 0, serverAssertion.Len())
442
}
443
444
func TestStartScanWithSocks5ShouldFindResultsWhenAServerIsAvailable(t *testing.T) {
445
	logger, _ := test.NewLogger()
446
447
	c, err := createCommand(logger)
448
	assert.NoError(t, err)
449
	assert.NotNil(t, c)
450
451
	testServer, serverAssertion := test.NewServerWithAssertion(
452
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
453
			w.WriteHeader(http.StatusNotFound)
454
		}),
455
	)
456
	defer testServer.Close()
457
458
	socks5Server := startSocks5TestServer(t)
459
	defer socks5Server.Close()
460
461
	_, _, err = executeCommand(
462
		c,
463
		"scan",
464
		testServer.URL,
465
		"--dictionary",
466
		"testdata/dict.txt",
467
		"-v",
468
		"--http-timeout",
469
		"300",
470
		"--socks5",
471
		socks5TestServerHost,
472
	)
473
	assert.NoError(t, err)
474
475
	assert.Equal(t, 3, serverAssertion.Len())
476
}
477
478
func TestShouldFailToScanWithAnUnreachableSocks5Server(t *testing.T) {
479
	logger, loggerBuffer := test.NewLogger()
480
481
	c, err := createCommand(logger)
482
	assert.NoError(t, err)
483
	assert.NotNil(t, c)
484
485
	testServer, serverAssertion := test.NewServerWithAssertion(
486
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
487
			w.WriteHeader(http.StatusNotFound)
488
		}),
489
	)
490
	defer testServer.Close()
491
492
	socks5Server := startSocks5TestServer(t)
493
	defer socks5Server.Close()
494
495
	_, _, err = executeCommand(
496
		c,
497
		"scan",
498
		testServer.URL,
499
		"--dictionary",
500
		"testdata/dict.txt",
501
		"-v",
502
		"--http-timeout",
503
		"300",
504
		"--socks5",
505
		"127.0.0.1:9555", // invalid
506
	)
507
	assert.NoError(t, err)
508
509
	assert.Equal(t, 0, serverAssertion.Len())
510
	assert.Contains(t, loggerBuffer.String(), "failed to perform request")
511
	assert.Contains(t, loggerBuffer.String(), "socks connect tcp")
512
	assert.Contains(t, loggerBuffer.String(), "connect: connection refused")
513
}
514
515
func TestShouldFailToStartWithAnInvalidSocks5Address(t *testing.T) {
516
	logger, _ := test.NewLogger()
517
518
	c, err := createCommand(logger)
519
	assert.NoError(t, err)
520
	assert.NotNil(t, c)
521
522
	testServer, serverAssertion := test.NewServerWithAssertion(
523
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
524
			w.WriteHeader(http.StatusNotFound)
525
		}),
526
	)
527
	defer testServer.Close()
528
529
	_, _, err = executeCommand(
530
		c,
531
		"scan",
532
		testServer.URL,
533
		"--dictionary",
534
		"testdata/dict.txt",
535
		"-v",
536
		"--http-timeout",
537
		"300",
538
		"--socks5",
539
		"localhost%%2", // invalid
540
	)
541
	assert.Error(t, err)
542
	assert.Contains(t, err.Error(), "invalid URL escape")
543
544
	assert.Equal(t, 0, serverAssertion.Len())
545
}
546
547
func startSocks5TestServer(t *testing.T) net.Listener {
548
	conf := &socks5.Config{}
549
	server, err := socks5.New(conf)
550
	if err != nil {
551
		t.Fatalf("failed to create socks5: %s", err.Error())
552
	}
553
554
	listener, err := net.Listen("tcp", socks5TestServerHost)
555
	if err != nil {
556
		t.Fatalf("failed to create listener: %s", err.Error())
557
	}
558
559
	go func() {
560
		// Create SOCKS5 proxy on localhost port 8000
561
		if err := server.Serve(listener); err != nil {
562
			t.Logf("socks5 stopped serving: %s", err.Error())
563
		}
564
	}()
565
566
	return listener
567
}
568