Passed
Pull Request — master (#41)
by Stefano
02:35
created

cmd_test.executeCommand   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
package cmd_test
2
3
import (
4
	"bytes"
5
	"io/ioutil"
6
	"net/http"
7
	"net/http/httptest"
8
	"os"
9
	"strings"
10
	"sync"
11
	"testing"
12
	"time"
13
14
	"github.com/pkg/errors"
15
	"github.com/sirupsen/logrus"
16
	"github.com/spf13/cobra"
17
	"github.com/stefanoj3/dirstalk/pkg/cmd"
18
	"github.com/stefanoj3/dirstalk/pkg/common/test"
19
	"github.com/stretchr/testify/assert"
20
)
21
22
// TestRootCommand runs the root command with no arguments and checks that the
// printed summary mentions the tool description, the usage section and the
// registered subcommands.
func TestRootCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	_, out, err := executeCommand(c)
	assert.NoError(t, err)

	// ensure the summary is printed
	assert.Contains(t, out, "dirstalk is a tool that attempts")
	assert.Contains(t, out, "Usage")
	assert.Contains(t, out, "dictionary.generate")
	assert.Contains(t, out, "scan")
}
38
39
// TestScanCommand runs a verbose scan with the local testdata/dict.txt
// dictionary against a server that answers 404 to every request. It expects
// exactly 3 requests to reach the server.
func TestScanCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(c, "scan", testServer.URL, "--dictionary", "testdata/dict.txt", "-v")
	assert.NoError(t, err)

	assert.Equal(t, 3, serverAssertion.Len())
}
58
59
// TestScanWithRemoteDictionary passes an HTTP URL as --dictionary. The
// dictionary server returns three entries, so the scanned server (which always
// answers 404) is expected to receive 3 requests.
func TestScanWithRemoteDictionary(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	dictionaryServer := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			dict := `home
home/index.php
blabla
`
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(dict))
		}),
	)
	defer dictionaryServer.Close()

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(c, "scan", testServer.URL, "--dictionary", dictionaryServer.URL)
	assert.NoError(t, err)

	assert.Equal(t, 3, serverAssertion.Len())
}
90
91
// TestScanWithUserAgentFlag checks that the value passed via --user-agent is
// sent as the User-Agent header on every request the scan makes.
func TestScanWithUserAgentFlag(t *testing.T) {
	const testUserAgent = "my_test_user_agent"

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--user-agent",
		testUserAgent,
		"--dictionary",
		"testdata/dict.txt",
	)
	assert.NoError(t, err)

	assert.Equal(t, 3, serverAssertion.Len())
	serverAssertion.Range(func(_ int, r http.Request) {
		assert.Equal(t, testUserAgent, r.Header.Get("User-Agent"))
	})
}
123
124
// TestScanWithCookies checks two things for repeated --cookie flags. Every
// request must carry both cookies, in the order given. Both cookies must also
// appear in the logger output.
func TestScanWithCookies(t *testing.T) {
	logger, loggerBuffer := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--cookie",
		"name1=val1",
		"--cookie",
		"name2=val2",
		"--dictionary",
		"testdata/dict.txt",
	)
	assert.NoError(t, err)

	// Without at least one recorded request the Range below would check
	// nothing and the test would pass vacuously.
	assert.True(t, serverAssertion.Len() > 0, "expected the scan to reach the test server")

	serverAssertion.Range(func(_ int, r http.Request) {
		cookies := r.Cookies()
		// Guard the indexing below so missing cookies fail the assertion
		// instead of panicking with an index out of range.
		if !assert.Len(t, cookies, 2) {
			return
		}

		assert.Equal(t, "name1", cookies[0].Name)
		assert.Equal(t, "val1", cookies[0].Value)

		assert.Equal(t, "name2", cookies[1].Name)
		assert.Equal(t, "val2", cookies[1].Value)
	})

	// to ensure we print the cookies to the cli
	assert.Contains(t, loggerBuffer.String(), "name1=val1")
	assert.Contains(t, loggerBuffer.String(), "name2=val2")
}
163
164
// TestWhenProvidingCookiesInWrongFormatShouldErr checks that a malformed
// --cookie value makes the scan fail. The error must mention the invalid
// format and the offending value, and no request may reach the server.
func TestWhenProvidingCookiesInWrongFormatShouldErr(t *testing.T) {
	const malformedCookie = "gibberish"

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--cookie",
		malformedCookie,
		"--dictionary",
		"testdata/dict.txt",
	)
	// err.Error() below would panic on a nil error, so stop if none was returned.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "cookie format is invalid")
		assert.Contains(t, err.Error(), malformedCookie)
	}

	assert.Equal(t, 0, serverAssertion.Len())
}
195
196
// TestScanWithCookieJar checks that --use-cookie-jar stores a cookie set by the
// server and sends it back on later requests. The server sets the cookie only
// on the first response, so the first request must carry no cookies and every
// later request exactly that one.
func TestScanWithCookieJar(t *testing.T) {
	const (
		serverCookieName  = "server_cookie_name"
		serverCookieValue = "server_cookie_value"
	)

	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	var once sync.Once
	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			once.Do(func() {
				http.SetCookie(
					w,
					&http.Cookie{
						Name:    serverCookieName,
						Value:   serverCookieValue,
						Expires: time.Now().AddDate(0, 1, 0),
					},
				)
			})
		}),
	)
	defer testServer.Close()

	// NOTE(review): "-t 1" presumably limits the scan to a single thread so
	// requests arrive in order and index 0 really is the first one — confirm.
	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--use-cookie-jar",
		"--dictionary",
		"testdata/dict.txt",
		"-t",
		"1",
	)
	assert.NoError(t, err)

	// The jar can only be observed from the second request onwards; fewer
	// requests would make the Range below pass without checking anything.
	assert.True(t, serverAssertion.Len() > 1, "expected at least two requests to verify the cookie jar")

	serverAssertion.Range(func(index int, r http.Request) {
		cookies := r.Cookies()
		if index == 0 { // first request should have no cookies
			assert.Empty(t, cookies)
			return
		}

		// Guard the indexing below so a missing cookie fails the assertion
		// instead of panicking with an index out of range.
		if !assert.Len(t, cookies, 1) {
			return
		}
		assert.Equal(t, serverCookieName, cookies[0].Name)
		assert.Equal(t, serverCookieValue, cookies[0].Value)
	})
}
248
249
// TestScanWithUnknownFlagShouldErr checks that an unrecognised flag aborts the
// scan with an "unknown flag" error before any request reaches the server.
func TestScanWithUnknownFlagShouldErr(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--gibberishflag",
		"--dictionary",
		"testdata/dict.txt",
	)
	// err.Error() would panic on a nil error, so only inspect it when present.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unknown flag")
	}

	assert.Equal(t, 0, serverAssertion.Len())
}
274
275
// TestScanWithHeaders checks that headers passed with --header are sent on
// every request and echoed to the logger output. It also checks that the
// surrounding double quotes on the Authorization value do not reach the
// server.
func TestScanWithHeaders(t *testing.T) {
	logger, loggerBuffer := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	_, _, err = executeCommand(
		c,
		"scan",
		testServer.URL,
		"--header",
		"Accept-Language: en-US,en;q=0.5",
		"--header",
		`"Authorization: Bearer 123"`,
		"--dictionary",
		"testdata/dict.txt",
	)
	assert.NoError(t, err)

	// Without at least one recorded request the Range below would check
	// nothing and the test would pass vacuously.
	assert.True(t, serverAssertion.Len() > 0, "expected the scan to reach the test server")

	serverAssertion.Range(func(_ int, r http.Request) {
		assert.Len(t, r.Header, 2)

		assert.Equal(t, "en-US,en;q=0.5", r.Header.Get("Accept-Language"))
		assert.Equal(t, "Bearer 123", r.Header.Get("Authorization"))
	})

	// to ensure we print the headers to the cli
	assert.Contains(t, loggerBuffer.String(), "Accept-Language")
	assert.Contains(t, loggerBuffer.String(), "Authorization")
	assert.Contains(t, loggerBuffer.String(), "Bearer 123")
}
312
313
// TestDictionaryGenerateCommand runs dictionary.generate on the current
// directory, writing to a randomly named file under testdata/ that is removed
// afterwards. It checks that this test file's own name appears in the output.
func TestDictionaryGenerateCommand(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	testFilePath := "testdata/" + test.RandStringRunes(10)
	defer removeTestFile(testFilePath)

	_, _, err = executeCommand(c, "dictionary.generate", ".", "-o", testFilePath)
	// If the command failed there is no output file to inspect.
	if !assert.NoError(t, err) {
		return
	}

	content, err := ioutil.ReadFile(testFilePath)
	assert.NoError(t, err)

	// Ensure the command ran and produced some of the expected output
	// it is not in the scope of this test to ensure the correct output
	assert.Contains(t, string(content), "root_integration_test.go")
}
332
333
// TestGenerateDictionaryWithoutOutputPath checks that dictionary.generate
// succeeds when no -o output path is given.
func TestGenerateDictionaryWithoutOutputPath(t *testing.T) {
	logger, _ := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	_, _, err = executeCommand(c, "dictionary.generate", ".")
	assert.NoError(t, err)
}
343
344
// TestGenerateDictionaryWithInvalidDirectory checks that dictionary.generate
// fails for a random, presumably non-existent path. The error must mention the
// unusable path.
func TestGenerateDictionaryWithInvalidDirectory(t *testing.T) {
	logger, _ := test.NewLogger()

	fakePath := "./" + test.RandStringRunes(10)
	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	_, _, err = executeCommand(c, "dictionary.generate", fakePath)
	// err.Error() would panic on a nil error, so stop if none was returned.
	if !assert.Error(t, err) {
		return
	}

	assert.Contains(t, err.Error(), "unable to use the provided path")
	assert.Contains(t, err.Error(), fakePath)
}
358
359
// TestVersionCommand runs the version subcommand and checks that a
// "Version: " line is written to the logger output, which createCommand
// passes to the version command.
func TestVersionCommand(t *testing.T) {
	logger, buf := test.NewLogger()

	c, err := createCommand(logger)
	// Stop on failure: continuing with a nil command would panic inside
	// executeCommand and hide the actual error.
	if !assert.NoError(t, err) || !assert.NotNil(t, c) {
		return
	}

	_, _, err = executeCommand(c, "version")
	assert.NoError(t, err)

	// Ensure the command ran and produced some of the expected output
	// it is not in the scope of this test to ensure the correct output
	assert.Contains(t, buf.String(), "Version: ")
}
373
374
// executeCommand runs root as if it had been invoked from the command line
// with args, capturing what cobra writes to the command output in a buffer.
//
// It returns the command that ExecuteC reports, the captured output and the
// execution error. The arguments are supplied by temporarily overwriting the
// global os.Args. Cobra is expected to read os.Args[1:] because SetArgs is
// never called. The original os.Args is restored before returning so the
// change cannot leak into other tests of the package.
func executeCommand(root *cobra.Command, args ...string) (c *cobra.Command, output string, err error) {
	buf := new(bytes.Buffer)
	root.SetOutput(buf)

	originalArgs := os.Args
	defer func() { os.Args = originalArgs }()

	// The empty first element stands in for the program name.
	os.Args = append([]string{""}, args...)

	c, err = root.ExecuteC()

	return c, buf.String(), err
}
385
386
// removeTestFile deletes path, but only when the path contains "testdata".
// The guard keeps a mistaken argument from deleting an unrelated file. Any
// removal error, such as the file never having been created, is deliberately
// ignored because this is best-effort cleanup.
func removeTestFile(path string) {
	if strings.Contains(path, "testdata") {
		_ = os.Remove(path)
	}
}
393
394
// createCommand builds the dirstalk root command and attaches the scan,
// dictionary.generate and version subcommands. The root and scan commands
// receive logger. The version command writes to logger.Out.
//
// Errors from constructing the root or scan command are returned wrapped with
// context, and the command is nil in that case.
func createCommand(logger *logrus.Logger) (*cobra.Command, error) {
	dirStalkCmd, err := cmd.NewRootCommand(logger)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create root command")
	}

	scanCmd, err := cmd.NewScanCommand(logger)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create scan command")
	}

	dirStalkCmd.AddCommand(scanCmd)
	dirStalkCmd.AddCommand(cmd.NewGenerateDictionaryCommand())
	dirStalkCmd.AddCommand(cmd.NewVersionCommand(logger.Out))

	return dirStalkCmd, nil
}
411