Passed
Push — master ( 3b5095...f7d107 )
by Stefano
08:05
created

cmd_test.TestScanWithMalformedHeaderShouldErr   A

Complexity

Conditions 2

Size

Total Lines 30
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 23
nop 1
dl 0
loc 30
rs 9.328
c 0
b 0
f 0
1
package cmd_test
2
3
import (
4
	"bytes"
5
	"io/ioutil"
6
	"net/http"
7
	"net/http/httptest"
8
	"os"
9
	"strings"
10
	"sync"
11
	"testing"
12
	"time"
13
14
	"github.com/pkg/errors"
15
	"github.com/sirupsen/logrus"
16
	"github.com/spf13/cobra"
17
	"github.com/stefanoj3/dirstalk/pkg/cmd"
18
	"github.com/stefanoj3/dirstalk/pkg/common/test"
19
	"github.com/stretchr/testify/assert"
20
)
21
22
// TestRootCommand runs the root command with no arguments and checks that the
// help/summary text, including the registered subcommands, is written to the
// command output.
func TestRootCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	rootCmd, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, rootCmd)

	_, output, err := executeCommand(rootCmd)
	assert.NoError(t, err)

	// The bare root command should print its summary and list the subcommands.
	expectedFragments := []string{
		"dirstalk is a tool that attempts",
		"Usage",
		"dictionary.generate",
		"scan",
	}
	for _, fragment := range expectedFragments {
		assert.Contains(t, output, fragment)
	}
}
38
39
// TestScanCommand runs a verbose scan with the local test dictionary against a
// server that answers 404 to everything, and checks the number of requests
// the server received.
func TestScanCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	rootCmd, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, rootCmd)

	notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	})
	server, recorded := test.NewServerWithAssertion(notFound)
	defer server.Close()

	scanArgs := []string{"scan", server.URL, "--dictionary", "testdata/dict.txt", "-v"}
	_, _, err = executeCommand(rootCmd, scanArgs...)
	assert.NoError(t, err)

	// The scan is expected to issue exactly three requests against the server.
	assert.Equal(t, 3, recorded.Len())
}
58
59
func TestScanWithRemoteDictionary(t *testing.T) {
60
	logger, _ := test.NewLogger()
61
62
	c, err := createCommand(logger)
63
	assert.NoError(t, err)
64
	assert.NotNil(t, c)
65
66
	dictionaryServer := httptest.NewServer(
67
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68
			dict := `home
69
home/index.php
70
blabla
71
`
72
			w.WriteHeader(http.StatusOK)
73
			_, _ = w.Write([]byte(dict))
74
		}),
75
	)
76
	defer dictionaryServer.Close()
77
78
	testServer, serverAssertion := test.NewServerWithAssertion(
79
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80
			w.WriteHeader(http.StatusNotFound)
81
		}),
82
	)
83
	defer testServer.Close()
84
85
	_, _, err = executeCommand(c, "scan", testServer.URL, "--dictionary", dictionaryServer.URL)
86
	assert.NoError(t, err)
87
88
	assert.Equal(t, 3, serverAssertion.Len())
89
}
90
91
// TestScanWithUserAgentFlag checks that --user-agent is sent on every scan
// request and that the configured value shows up in the logger output.
func TestScanWithUserAgentFlag(t *testing.T) {
	const testUserAgent = "my_test_user_agent"

	logger, logOutput := test.NewLogger()

	rootCmd, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, rootCmd)

	server, recorded := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer server.Close()

	scanArgs := []string{
		"scan", server.URL,
		"--user-agent", testUserAgent,
		"--dictionary", "testdata/dict.txt",
	}
	_, _, err = executeCommand(rootCmd, scanArgs...)
	assert.NoError(t, err)

	// Every request that reached the server must carry the custom agent.
	assert.Equal(t, 3, recorded.Len())
	recorded.Range(func(_ int, req http.Request) {
		assert.Equal(t, testUserAgent, req.Header.Get("User-Agent"))
	})

	// The configured user agent is also echoed through the logger.
	assert.Contains(t, logOutput.String(), testUserAgent)
}
126
127
// TestScanWithCookies checks that every --cookie flag is attached, in the order
// given, to each request the scan sends, and that the cookies are echoed in the
// logger output.
func TestScanWithCookies(t *testing.T) {
	logger, loggerBuffer := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--cookie",
		"name1=val1",
		"--cookie",
		"name2=val2",
		"--dictionary",
		"testdata/dict.txt",
	)
	assert.NoError(t, err)

	// Range over an empty recording would pass vacuously; make sure the scan
	// actually sent something before checking the individual requests.
	assert.NotZero(t, serverAssertion.Len())

	serverAssertion.Range(func(_ int, r http.Request) {
		cookies := r.Cookies()
		// Guard the indexing below: a missing cookie must fail the assertion,
		// not panic and take the whole test binary down with it.
		if !assert.Len(t, cookies, 2) {
			return
		}

		assert.Equal(t, "name1", cookies[0].Name)
		assert.Equal(t, "val1", cookies[0].Value)

		assert.Equal(t, "name2", cookies[1].Name)
		assert.Equal(t, "val2", cookies[1].Value)
	})

	// to ensure we print the cookies to the cli
	assert.Contains(t, loggerBuffer.String(), "name1=val1")
	assert.Contains(t, loggerBuffer.String(), "name2=val2")
}
166
167
// TestWhenProvidingCookiesInWrongFormatShouldErr checks that a malformed
// --cookie value is rejected with an error naming the offending value, and that
// no request reaches the target server.
func TestWhenProvidingCookiesInWrongFormatShouldErr(t *testing.T) {
	const malformedCookie = "gibberish"

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--cookie",
		malformedCookie,
		"--dictionary",
		"testdata/dict.txt",
	)
	// Only inspect the message when an error was actually returned; calling
	// err.Error() on a nil error would panic and abort the whole test binary.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "cookie format is invalid")
		assert.Contains(t, err.Error(), malformedCookie)
	}

	// Validation must fail before any request is sent.
	assert.Equal(t, 0, serverAssertion.Len())
}
198
199
// TestScanWithCookieJar checks --use-cookie-jar. The server sets a cookie only
// on the first response it writes. The first recorded request must carry no
// cookies, and every later request must send that cookie back.
func TestScanWithCookieJar(t *testing.T) {
	const (
		serverCookieName  = "server_cookie_name"
		serverCookieValue = "server_cookie_value"
	)

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	// The cookie is set exactly once, so any cookie seen on later requests
	// must have been replayed by the client rather than set per response.
	var once sync.Once
	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			once.Do(func() {
				http.SetCookie(
					w,
					&http.Cookie{
						Name:    serverCookieName,
						Value:   serverCookieValue,
						Expires: time.Now().AddDate(0, 1, 0),
					},
				)
			})
		}),
	)
	defer testServer.Close()

	// NOTE(review): "-t 1" presumably limits the scan to a single worker so the
	// index-0 request is really the first one served — confirm against the flag.
	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--use-cookie-jar",
		"--dictionary",
		"testdata/dict.txt",
		"-t",
		"1",
	)
	assert.NoError(t, err)

	// Range over an empty recording would pass vacuously.
	assert.NotZero(t, serverAssertion.Len())

	serverAssertion.Range(func(index int, r http.Request) {
		cookies := r.Cookies()

		if index == 0 { // first request should have no cookies
			assert.Len(t, cookies, 0)
			return
		}

		// Guard the indexing below so a missing cookie fails instead of panicking.
		if !assert.Len(t, cookies, 1) {
			return
		}
		assert.Equal(t, serverCookieName, cookies[0].Name)
		assert.Equal(t, serverCookieValue, cookies[0].Value)
	})
}
251
252
// TestScanWithUnknownFlagShouldErr checks that an unrecognised flag makes the
// scan command fail with an "unknown flag" error before any request is sent.
func TestScanWithUnknownFlagShouldErr(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--gibberishflag",
		"--dictionary",
		"testdata/dict.txt",
	)
	// Only inspect the message when an error was actually returned; calling
	// err.Error() on a nil error would panic and abort the whole test binary.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unknown flag")
	}

	assert.Equal(t, 0, serverAssertion.Len())
}
277
278
// TestScanWithHeaders checks that repeated --header flags are sent with every
// scan request and that the headers are echoed in the logger output.
func TestScanWithHeaders(t *testing.T) {
	logger, loggerBuffer := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--header",
		"Accept-Language: en-US,en;q=0.5",
		"--header",
		// Deliberately wrapped in double quotes: the assertions below expect
		// them to be stripped from the value that reaches the server.
		`"Authorization: Bearer 123"`,
		"--dictionary",
		"testdata/dict.txt",
	)
	assert.NoError(t, err)

	// Range over an empty recording would pass vacuously.
	assert.NotZero(t, serverAssertion.Len())

	serverAssertion.Range(func(_ int, r http.Request) {
		assert.Len(t, r.Header, 2)

		assert.Equal(t, "en-US,en;q=0.5", r.Header.Get("Accept-Language"))
		assert.Equal(t, "Bearer 123", r.Header.Get("Authorization"))
	})

	// to ensure we print the headers to the cli
	assert.Contains(t, loggerBuffer.String(), "Accept-Language")
	assert.Contains(t, loggerBuffer.String(), "Authorization")
	assert.Contains(t, loggerBuffer.String(), "Bearer 123")
}
315
316
// TestScanWithMalformedHeaderShouldErr checks that a --header value without a
// "Name: value" form is rejected with an error naming it, even when a valid
// header is also supplied, and that no request reaches the target server.
func TestScanWithMalformedHeaderShouldErr(t *testing.T) {
	const malformedHeader = "gibberish"

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--header",
		"Accept-Language: en-US,en;q=0.5",
		"--header",
		malformedHeader,
		"--dictionary",
		"testdata/dict.txt",
	)
	// Only inspect the message when an error was actually returned; calling
	// err.Error() on a nil error would panic and abort the whole test binary.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), malformedHeader)
		assert.Contains(t, err.Error(), "header is in invalid format")
	}

	// Validation must fail before any request is sent.
	assert.Equal(t, 0, serverAssertion.Len())
}
347
348
// TestDictionaryGenerateCommand generates a dictionary from the current
// directory into a randomly named file under testdata and checks that the
// file contains an expected entry.
func TestDictionaryGenerateCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	rootCmd, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, rootCmd)

	// Random name under testdata; cleanup is registered before the command runs.
	outputPath := "testdata/" + test.RandStringRunes(10)
	defer removeTestFile(outputPath)

	_, _, err = executeCommand(rootCmd, "dictionary.generate", ".", "-o", outputPath)
	assert.NoError(t, err)

	generated, err := ioutil.ReadFile(outputPath)
	assert.NoError(t, err)

	// Smoke check only: the generated dictionary should list
	// root_integration_test.go; validating the full output is out of scope.
	assert.Contains(t, string(generated), "root_integration_test.go")
}
367
368
// TestGenerateDictionaryWithoutOutputPath checks that dictionary.generate
// succeeds when no -o output path is supplied.
func TestGenerateDictionaryWithoutOutputPath(t *testing.T) {
	logger, _ := test.NewLogger()

	rootCmd, createErr := createCommand(logger)
	assert.NoError(t, createErr)
	assert.NotNil(t, rootCmd)

	_, _, runErr := executeCommand(rootCmd, "dictionary.generate", ".")
	assert.NoError(t, runErr)
}
378
379
// TestGenerateDictionaryWithInvalidDirectory checks that dictionary.generate
// fails with a descriptive error that includes the path when given a path it
// cannot use.
func TestGenerateDictionaryWithInvalidDirectory(t *testing.T) {
	logger, _ := test.NewLogger()

	// A random relative path, assumed not to exist in the working directory.
	fakePath := "./" + test.RandStringRunes(10)
	c, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, c)

	_, _, err = executeCommand(c, "dictionary.generate", fakePath)
	// Only inspect the message when an error was actually returned; calling
	// err.Error() on a nil error would panic and abort the whole test binary.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to use the provided path")
		assert.Contains(t, err.Error(), fakePath)
	}
}
393
394
// TestVersionCommand runs the version subcommand and checks that it writes
// version information to the logger's output.
func TestVersionCommand(t *testing.T) {
	logger, logOutput := test.NewLogger()

	rootCmd, err := createCommand(logger)
	assert.NoError(t, err)
	assert.NotNil(t, rootCmd)

	_, _, err = executeCommand(rootCmd, "version")
	assert.NoError(t, err)

	// Smoke check only: something recognisable was printed; the exact
	// content is out of scope for this test.
	printed := logOutput.String()
	assert.Contains(t, printed, "Version: ")
}
408
409
// executeCommand runs root as if it had been invoked from the command line
// with args, capturing everything cobra writes to the command output. It
// returns the command that was executed, the captured output and the error
// returned by ExecuteC.
//
// The arguments are supplied through os.Args, with an empty program name in
// position 0. The previous os.Args is restored before returning so that one
// test's arguments cannot leak into later tests or other code reading os.Args.
//
// NOTE(review): root.SetArgs would avoid touching process-global state
// altogether; os.Args is kept in case the commands read it directly — confirm.
func executeCommand(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) {
	buf := new(bytes.Buffer)
	root.SetOutput(buf)

	originalArgs := os.Args
	defer func() { os.Args = originalArgs }()

	os.Args = append([]string{""}, args...)

	c, err = root.ExecuteC()

	return c, buf.String(), err
}
420
421
// removeTestFile deletes the file at path, but only if path contains
// "testdata". Any other path is left untouched, which guards against deleting
// files outside the test fixtures. Removal is best effort: the error from
// os.Remove is deliberately discarded.
func removeTestFile(path string) {
	if isFixture := strings.Contains(path, "testdata"); isFixture {
		// Best-effort cleanup; a missing or undeletable file is not a failure.
		_ = os.Remove(path)
	}
}
428
429
// createCommand builds the command tree used by these tests: the root
// command with the scan, dictionary.generate and version subcommands
// attached. The version command is given logger.Out as its writer, so its
// output lands in the same buffer as the logger's. Construction errors are
// wrapped with which command failed to build.
func createCommand(logger *logrus.Logger) (*cobra.Command, error) {
	dirStalkCmd, err := cmd.NewRootCommand(logger)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create root command")
	}

	scanCmd, err := cmd.NewScanCommand(logger)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create scan command")
	}

	dirStalkCmd.AddCommand(
		scanCmd,
		cmd.NewGenerateDictionaryCommand(),
		cmd.NewVersionCommand(logger.Out),
	)

	return dirStalkCmd, nil
}
446