Passed
Push — master ( 107211...10f7a6 )
by Stefano
01:22
created

cmd.buildDictionary   A

Complexity

Conditions 3

Size

Total Lines 12
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 12
rs 10
c 0
b 0
f 0
1
package cmd
2
3
import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/spf13/cobra"
	"github.com/stefanoj3/dirstalk/pkg/cmd/termination"
	"github.com/stefanoj3/dirstalk/pkg/common"
	"github.com/stefanoj3/dirstalk/pkg/dictionary"
	"github.com/stefanoj3/dirstalk/pkg/scan"
	"github.com/stefanoj3/dirstalk/pkg/scan/client"
	"github.com/stefanoj3/dirstalk/pkg/scan/filter"
	"github.com/stefanoj3/dirstalk/pkg/scan/output"
	"github.com/stefanoj3/dirstalk/pkg/scan/producer"
	"github.com/stefanoj3/dirstalk/pkg/scan/summarizer"
	"github.com/stefanoj3/dirstalk/pkg/scan/summarizer/tree"
)
25
26
func NewScanCommand(logger *logrus.Logger) *cobra.Command {
27
	cmd := &cobra.Command{
28
		Use:   "scan [url]",
29
		Short: "Scan the given URL",
30
		RunE:  buildScanFunction(logger),
31
	}
32
33
	cmd.Flags().StringP(
34
		flagScanDictionary,
35
		flagScanDictionaryShort,
36
		"",
37
		"dictionary to use for the scan (path to local file or remote url)",
38
	)
39
	common.Must(cmd.MarkFlagFilename(flagScanDictionary))
40
	common.Must(cmd.MarkFlagRequired(flagScanDictionary))
41
42
	cmd.Flags().IntP(
43
		flagScanDictionaryGetTimeout,
44
		"",
45
		50000,
46
		"timeout in milliseconds (used when fetching remote dictionary)",
47
	)
48
49
	cmd.Flags().StringSlice(
50
		flagScanHTTPMethods,
51
		[]string{"GET"},
52
		"comma separated list of http methods to use; eg: GET,POST,PUT",
53
	)
54
55
	cmd.Flags().IntSlice(
56
		flagScanHTTPStatusesToIgnore,
57
		[]int{http.StatusNotFound},
58
		"comma separated list of http statuses to ignore when showing and processing results; eg: 404,301",
59
	)
60
61
	cmd.Flags().IntP(
62
		flagScanThreads,
63
		flagScanThreadsShort,
64
		3,
65
		"amount of threads for concurrent requests",
66
	)
67
68
	cmd.Flags().IntP(
69
		flagScanHTTPTimeout,
70
		"",
71
		5000,
72
		"timeout in milliseconds",
73
	)
74
75
	cmd.Flags().BoolP(
76
		flagScanHTTPCacheRequests,
77
		"",
78
		true,
79
		"cache requests to avoid performing the same request multiple times within the same scan (EG if the "+
80
			"server reply with the same redirect location multiple times, dirstalk will follow it only once)",
81
	)
82
83
	cmd.Flags().IntP(
84
		flagScanScanDepth,
85
		"",
86
		3,
87
		"scan depth",
88
	)
89
90
	cmd.Flags().StringP(
91
		flagScanSocks5Host,
92
		"",
93
		"",
94
		"socks5 host to use",
95
	)
96
97
	cmd.Flags().StringP(
98
		flagScanUserAgent,
99
		"",
100
		"",
101
		"user agent to use for http requests",
102
	)
103
104
	cmd.Flags().BoolP(
105
		flagScanCookieJar,
106
		"",
107
		false,
108
		"enables the use of a cookie jar: it will retain any cookie sent "+
109
			"from the server and send them for the following requests",
110
	)
111
112
	cmd.Flags().StringArray(
113
		flagScanCookie,
114
		[]string{},
115
		"cookie to add to each request; eg name=value (can be specified multiple times)",
116
	)
117
118
	cmd.Flags().StringArray(
119
		flagScanHeader,
120
		[]string{},
121
		"header to add to each request; eg name=value (can be specified multiple times)",
122
	)
123
124
	cmd.Flags().String(
125
		flagScanResultOutput,
126
		"",
127
		"path where to store result output",
128
	)
129
130
	cmd.Flags().Bool(
131
		flagShouldSkipSSLCertificatesValidation,
132
		false,
133
		"to skip checking the validity of SSL certificates",
134
	)
135
136
	return cmd
137
}
138
139
// buildScanFunction returns the cobra RunE handler for the scan command.
//
// The handler takes these steps:
//   - parse the target URL from the positional args via getURL;
//   - build the scan configuration from the command's flags;
//   - hand both to startScan.
//
// URL errors are returned as-is. Config errors are wrapped with
// "failed to build config".
func buildScanFunction(logger *logrus.Logger) func(cmd *cobra.Command, args []string) error {
	return func(cmd *cobra.Command, args []string) error {
		target, urlErr := getURL(args)
		if urlErr != nil {
			return urlErr
		}

		cnf, cnfErr := scanConfigFromCmd(cmd)
		if cnfErr != nil {
			return errors.Wrap(cnfErr, "failed to build config")
		}

		return startScan(logger, cnf, target)
	}
}
156
157
func getURL(args []string) (*url.URL, error) {
158
	if len(args) == 0 {
159
		return nil, errors.New("no URL provided")
160
	}
161
162
	arg := args[0]
163
164
	u, err := url.ParseRequestURI(arg)
165
	if err != nil {
166
		return nil, errors.Wrap(err, "the first argument must be a valid url")
167
	}
168
169
	return u, nil
170
}
171
172
// startScan is a convenience method that wires together all the dependencies needed to start a scan
173
func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error {
174
	dict, err := buildDictionary(cnf, u)
175
	if err != nil {
176
		return err
177
	}
178
179
	s, err := buildScanner(cnf, dict, u, logger)
180
	if err != nil {
181
		return err
182
	}
183
184
	logger.WithFields(logrus.Fields{
185
		"url":               u.String(),
186
		"threads":           cnf.Threads,
187
		"dictionary-length": len(dict),
188
		"scan-depth":        cnf.ScanDepth,
189
		"timeout":           cnf.TimeoutInMilliseconds,
190
		"socks5":            cnf.Socks5Url,
191
		"cookies":           stringifyCookies(cnf.Cookies),
192
		"cookie-jar":        cnf.UseCookieJar,
193
		"headers":           stringifyHeaders(cnf.Headers),
194
		"user-agent":        cnf.UserAgent,
195
	}).Info("Starting scan")
196
197
	resultSummarizer := summarizer.NewResultSummarizer(tree.NewResultTreeProducer(), logger)
198
199
	osSigint := make(chan os.Signal, 1)
200
	signal.Notify(osSigint, os.Interrupt)
201
202
	outputSaver, err := newOutputSaver(cnf.Out)
203
	if err != nil {
204
		return errors.Wrap(err, "failed to create output saver")
205
	}
206
207
	defer func() {
208
		resultSummarizer.Summarize()
209
		err := outputSaver.Close()
210
		if err != nil {
211
			logger.WithError(err).Error("failed to close output file")
212
		}
213
		logger.Info("Finished scan")
214
	}()
215
216
	ctx, cancellationFunc := context.WithCancel(context.Background())
217
	defer cancellationFunc()
218
219
	resultsChannel := s.Scan(ctx, u, cnf.Threads)
220
221
	terminationHandler := termination.NewTerminationHandler(2)
222
223
	for {
224
		select {
225
		case <-osSigint:
226
			terminationHandler.SignalTermination()
227
			cancellationFunc()
228
229
			if terminationHandler.ShouldTerminate() {
230
				logger.Info("Received sigint, terminating...")
231
				return nil
232
			}
233
234
			logger.Info(
235
				"Received sigint, trying to shutdown gracefully, another SIGNINT will terminate the application",
236
			)
237
		case result, ok := <-resultsChannel:
238
			if !ok {
239
				logger.Debug("result channel is being closed, scan should be complete")
240
				return nil
241
			}
242
243
			resultSummarizer.Add(result)
244
245
			if err := outputSaver.Save(result); err != nil {
246
				return errors.Wrap(err, "failed to add output to file")
247
			}
248
		}
249
	}
250
}
251
252
func buildScanner(cnf *scan.Config, dict []string, u *url.URL, logger *logrus.Logger) (*scan.Scanner, error) {
253
	targetProducer := producer.NewDictionaryProducer(cnf.HTTPMethods, dict, cnf.ScanDepth)
254
	reproducer := producer.NewReProducer(targetProducer)
255
256
	resultFilter := filter.NewHTTPStatusResultFilter(cnf.HTTPStatusesToIgnore)
257
258
	scannerClient, err := buildScannerClient(cnf, u)
259
	if err != nil {
260
		return nil, err
261
	}
262
263
	s := scan.NewScanner(
264
		scannerClient,
265
		targetProducer,
266
		reproducer,
267
		resultFilter,
268
		logger,
269
	)
270
271
	return s, nil
272
}
273
274
// buildDictionary loads the word list referenced by cnf.DictionaryPath.
//
// It first builds a dedicated HTTP client via buildDictionaryClient, then
// passes it to dictionary.NewDictionaryFrom. Per the flag help text, the path
// may be a local file or a remote URL.
//
// Client errors are returned unchanged. Loading errors are wrapped with
// "failed to build dictionary".
func buildDictionary(cnf *scan.Config, u *url.URL) ([]string, error) {
	httpClient, err := buildDictionaryClient(cnf, u)
	if err != nil {
		return nil, err
	}

	words, loadErr := dictionary.NewDictionaryFrom(cnf.DictionaryPath, httpClient)
	if loadErr == nil {
		return words, nil
	}

	return nil, errors.Wrap(loadErr, "failed to build dictionary")
}
287
288
func buildScannerClient(cnf *scan.Config, u *url.URL) (*http.Client, error) {
289
	c, err := client.NewClientFromConfig(
290
		cnf.TimeoutInMilliseconds,
291
		cnf.Socks5Url,
292
		cnf.UserAgent,
293
		cnf.UseCookieJar,
294
		cnf.Cookies,
295
		cnf.Headers,
296
		cnf.CacheRequests,
297
		cnf.ShouldSkipSSLCertificatesValidation,
298
		u,
299
	)
300
	if err != nil {
301
		return nil, errors.Wrap(err, "failed to build scanner client")
302
	}
303
304
	return c, nil
305
}
306
307
func buildDictionaryClient(cnf *scan.Config, u *url.URL) (*http.Client, error) {
308
	c, err := client.NewClientFromConfig(
309
		cnf.DictionaryTimeoutInMilliseconds,
310
		cnf.Socks5Url,
311
		cnf.UserAgent,
312
		cnf.UseCookieJar,
313
		cnf.Cookies,
314
		cnf.Headers,
315
		cnf.CacheRequests,
316
		cnf.ShouldSkipSSLCertificatesValidation,
317
		u,
318
	)
319
	if err != nil {
320
		return nil, errors.Wrap(err, "failed to build dictionary client")
321
	}
322
323
	return c, nil
324
}
325
326
// newOutputSaver returns the OutputSaver that results are written to.
//
// An empty path selects output.NewNullSaver(), meaning no output is persisted.
// Any other path is handed to output.NewFileSaver.
//
// On error the returned saver is a literal nil interface. Returning the
// constructor's result directly could wrap a typed nil pointer in a non-nil
// OutputSaver, which would then compare != nil.
func newOutputSaver(path string) (OutputSaver, error) {
	if path == "" {
		return output.NewNullSaver(), nil
	}

	saver, err := output.NewFileSaver(path)
	if err != nil {
		return nil, err
	}

	return saver, nil
}
333
334
// stringifyCookies renders cookies as concatenated "{name=value}" pairs, in
// slice order, for use as a log field. Nil entries are skipped rather than
// dereferenced. An empty or nil slice yields "".
func stringifyCookies(cookies []*http.Cookie) string {
	// A Builder avoids re-copying the accumulated string on every iteration.
	var b strings.Builder

	for _, cookie := range cookies {
		if cookie == nil {
			continue
		}
		fmt.Fprintf(&b, "{%s=%s}", cookie.Name, cookie.Value)
	}

	return b.String()
}
343
344
// stringifyHeaders renders headers as concatenated "{name:value}" pairs for
// use as a log field. An empty or nil map yields "".
//
// Go randomizes map iteration order, so names are sorted first. This makes the
// same configuration always produce the same log line.
func stringifyHeaders(headers map[string]string) string {
	names := make([]string, 0, len(headers))
	for name := range headers {
		names = append(names, name)
	}
	sort.Strings(names)

	var b strings.Builder
	for _, name := range names {
		fmt.Fprintf(&b, "{%s:%s}", name, headers[name])
	}

	return b.String()
}
353