Passed
Push — master ( a2eed9...a44a1f )
by Abouzar
06:50
created

main.TestCopyContentWithSharedLimiter   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 1
dl 0
loc 16
rs 9.85
c 0
b 0
f 0
1
package main
2
3
import (
4
	"context"
5
	"net/http"
6
	"net/http/httptest"
7
	"os"
8
	"os/user"
9
	"path/filepath"
10
	"strings"
11
	"testing"
12
	"time"
13
14
	"golang.org/x/time/rate"
15
)
16
17
func TestPartCalculate(t *testing.T) {
18
	// Disable progress bar for tests
19
	displayProgress = false
20
21
	// Setup test environment
22
	originalDataFolder := dataFolder
23
	dataFolder = ".hget_test/"
24
	defer func() {
25
		dataFolder = originalDataFolder
26
		usr, _ := user.Current()
27
		testFolder := filepath.Join(usr.HomeDir, dataFolder)
28
		os.RemoveAll(testFolder)
29
	}()
30
31
	// Test with different numbers of parts
32
	testCases := []struct {
33
		parts       int64
34
		totalSize   int64
35
		url         string
36
		expectParts int
37
	}{
38
		{10, 100, "http://foo.bar/file", 10},
39
		{5, 1000, "http://example.com/largefile", 5},
40
		{1, 50, "http://test.org/smallfile", 1},
41
		{3, 10, "http://tiny.file/data", 3}, // Small file, multiple parts
42
	}
43
44
	for _, tc := range testCases {
45
		parts := partCalculate(tc.parts, tc.totalSize, tc.url)
46
47
		// Check number of parts
48
		if len(parts) != tc.expectParts {
49
			t.Errorf("Expected %d parts, got %d", tc.expectParts, len(parts))
50
		}
51
52
		// Check part URLs
53
		for i, part := range parts {
54
			if part.URL != tc.url {
55
				t.Errorf("Part %d: Expected URL %s, got %s", i, tc.url, part.URL)
56
			}
57
58
			// Check part index
59
			if part.Index != int64(i) {
60
				t.Errorf("Part %d: Expected Index %d, got %d", i, i, part.Index)
61
			}
62
63
			// Check ranges
64
			expectedSize := tc.totalSize / tc.parts
65
			if i < int(tc.parts-1) {
66
				if part.RangeFrom != expectedSize*int64(i) {
67
					t.Errorf("Part %d: Expected RangeFrom %d, got %d",
68
						i, expectedSize*int64(i), part.RangeFrom)
69
				}
70
				if part.RangeTo != expectedSize*int64(i+1)-1 {
71
					t.Errorf("Part %d: Expected RangeTo %d, got %d",
72
						i, expectedSize*int64(i+1)-1, part.RangeTo)
73
				}
74
			} else {
75
				// Last part might be larger due to division remainder
76
				if part.RangeFrom != expectedSize*int64(i) {
77
					t.Errorf("Part %d: Expected RangeFrom %d, got %d",
78
						i, expectedSize*int64(i), part.RangeFrom)
79
				}
80
				if part.RangeTo != tc.totalSize {
81
					t.Errorf("Part %d: Expected RangeTo %d, got %d",
82
						i, tc.totalSize, part.RangeTo)
83
				}
84
			}
85
86
			// Check path format
87
			usr, _ := user.Current()
88
			expectedBasePath := filepath.Join(usr.HomeDir, dataFolder)
89
			if !strings.Contains(part.Path, expectedBasePath) {
90
				t.Errorf("Part %d: Path does not contain expected base path: %s", i, part.Path)
91
			}
92
93
			fileName := filepath.Base(part.Path)
94
			expectedPrefix := TaskFromURL(tc.url) + ".part"
95
			if !strings.HasPrefix(fileName, expectedPrefix) {
96
				t.Errorf("Part %d: Expected filename prefix %s, got %s",
97
					i, expectedPrefix, fileName)
98
			}
99
		}
100
	}
101
}
102
103
func TestProxyAwareHTTPClient(t *testing.T) {
104
	// Test with no proxy
105
	client := ProxyAwareHTTPClient("", false)
106
	if client == nil {
107
		t.Fatal("ProxyAwareHTTPClient returned nil with no proxy")
108
	}
109
110
	// Cannot easily test with an actual proxy, but can verify it doesn't crash
111
	httpProxyClient := ProxyAwareHTTPClient("http://localhost:8080", false)
112
	if httpProxyClient == nil {
113
		t.Fatal("ProxyAwareHTTPClient returned nil with HTTP proxy")
114
	}
115
116
	socksProxyClient := ProxyAwareHTTPClient("localhost:1080", false)
117
	if socksProxyClient == nil {
118
		t.Fatal("ProxyAwareHTTPClient returned nil with SOCKS proxy")
119
	}
120
121
	// Test TLS skipVerify parameter
122
	tlsClient := ProxyAwareHTTPClient("", true)
123
	if tlsClient == nil {
124
		t.Fatal("ProxyAwareHTTPClient returned nil with TLS skip verification")
125
	}
126
127
	// Can't directly access TLS config, but it shouldn't crash
128
}
129
130
// parseInt converts the leading decimal integer in s to an int, in the
// manner of C's atoi:
//   - An optional leading '+' or '-' sign is honoured.
//   - Parsing stops at the first byte that is not an ASCII digit.
//   - It returns 0 when s has no leading digits.
//
// Overflow is not detected.
func parseInt(s string) int {
	negative := false
	if s != "" && (s[0] == '-' || s[0] == '+') {
		negative = s[0] == '-'
		s = s[1:]
	}

	n := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		// Stop rather than fold the byte in: c-'0' on a non-digit is a
		// meaningless value that would silently corrupt the result.
		if c < '0' || c > '9' {
			break
		}
		n = n*10 + int(c-'0')
	}

	if negative {
		return -n
	}
	return n
}
138
139
func TestHandleCompletedPart(t *testing.T) {
140
	// Disable progress bar for tests
141
	displayProgress = false
142
143
	// Create a part that's already complete
144
	part := Part{
145
		Index:     0,
146
		URL:       "http://example.com/test",
147
		Path:      "test.part000000",
148
		RangeFrom: 100, // RangeFrom equals RangeTo means no data to download
149
		RangeTo:   100,
150
	}
151
152
	// Create channels
153
	fileChan := make(chan string, 1)
154
	stateSaveChan := make(chan Part, 1)
155
156
	// Create downloader
157
	downloader := &HTTPDownloader{
158
		url:       "http://example.com/test",
159
		file:      "test",
160
		par:       1,
161
		len:       100,
162
		parts:     []Part{part},
163
		resumable: true,
164
	}
165
166
	// Handle the completed part
167
	downloader.handleCompletedPart(part, fileChan, stateSaveChan)
168
169
	// Verify the path was sent to fileChan
170
	select {
171
	case path := <-fileChan:
172
		if path != part.Path {
173
			t.Errorf("Expected path %s, got %s", part.Path, path)
174
		}
175
	default:
176
		t.Errorf("No path sent to fileChan")
177
	}
178
179
	// Verify the part was sent to stateSaveChan
180
	select {
181
	case savedPart := <-stateSaveChan:
182
		if savedPart.Index != part.Index ||
183
			savedPart.URL != part.URL ||
184
			savedPart.Path != part.Path ||
185
			savedPart.RangeFrom != part.RangeFrom ||
186
			savedPart.RangeTo != part.RangeTo {
187
			t.Errorf("Saved part does not match original part")
188
		}
189
	default:
190
		t.Errorf("No part sent to stateSaveChan")
191
	}
192
}
193
194
func TestBuildRequestForPart(t *testing.T) {
195
	// Test cases for different range situations
196
	testCases := []struct {
197
		description string
198
		part        Part
199
		contentLen  int64
200
		parallelism int64
201
		expected    string
202
	}{
203
		{
204
			description: "Single connection download (no range)",
205
			part: Part{
206
				Index:     0,
207
				URL:       "http://example.com/file",
208
				Path:      "file.part000000",
209
				RangeFrom: 0,
210
				RangeTo:   100,
211
			},
212
			contentLen:  100,
213
			parallelism: 1,
214
			expected:    "", // No range header expected
215
		},
216
		{
217
			description: "Multiple connection download with middle part",
218
			part: Part{
219
				Index:     1,
220
				URL:       "http://example.com/file",
221
				Path:      "file.part000001",
222
				RangeFrom: 50,
223
				RangeTo:   99,
224
			},
225
			contentLen:  200,
226
			parallelism: 3,
227
			expected:    "bytes=50-99",
228
		},
229
		{
230
			description: "Multiple connection download with last part",
231
			part: Part{
232
				Index:     2,
233
				URL:       "http://example.com/file",
234
				Path:      "file.part000002",
235
				RangeFrom: 100,
236
				RangeTo:   200,
237
			},
238
			contentLen:  200,
239
			parallelism: 3,
240
			expected:    "bytes=100-",
241
		},
242
	}
243
244
	for _, tc := range testCases {
245
		t.Run(tc.description, func(t *testing.T) {
246
			// Create downloader
247
			downloader := &HTTPDownloader{
248
				url:       tc.part.URL,
249
				file:      "file",
250
				par:       tc.parallelism,
251
				len:       tc.contentLen,
252
				parts:     []Part{tc.part},
253
				resumable: true,
254
			}
255
256
			// Build request
257
			req, err := downloader.buildRequestForPart(context.Background(), tc.part)
258
			if err != nil {
259
				t.Fatalf("buildRequestForPart failed: %v", err)
260
			}
261
262
			// Check range header
263
			rangeHeader := req.Header.Get("Range")
264
			if tc.expected == "" {
265
				if rangeHeader != "" {
266
					t.Errorf("Expected no Range header, got '%s'", rangeHeader)
267
				}
268
			} else {
269
				if rangeHeader != tc.expected {
270
					t.Errorf("Expected Range header '%s', got '%s'", tc.expected, rangeHeader)
271
				}
272
			}
273
274
			// Check URL
275
			if req.URL.String() != tc.part.URL {
276
				t.Errorf("Expected URL %s, got %s", tc.part.URL, req.URL.String())
277
			}
278
		})
279
	}
280
}
281
282
func TestCopyContent(t *testing.T) {
283
	// Create test data
284
	testData := "This is test data for copy content"
285
286
	// Test regular copy (no rate limit)
287
	t.Run("No Rate Limit", func(t *testing.T) {
288
		// Create source and destination
289
		src := strings.NewReader(testData)
290
		var dst strings.Builder
291
292
		// Create downloader with no rate limit
293
		downloader := &HTTPDownloader{
294
			rate: 0,
295
		}
296
297
		// Copy content
298
		done := make(chan bool)
299
		go downloader.copyContent(src, &dst, done)
300
301
		// Wait for completion
302
		<-done
303
304
		// Verify copied content
305
		if dst.String() != testData {
306
			t.Errorf("Expected content '%s', got '%s'", testData, dst.String())
307
		}
308
	})
309
310
	// Test copy with rate limit (can only test that it doesn't crash)
311
	t.Run("With Rate Limit", func(t *testing.T) {
312
		// Create source and destination
313
		src := strings.NewReader(testData)
314
		var dst strings.Builder
315
316
		// Create downloader with rate limit
317
		downloader := &HTTPDownloader{
318
			rate: 1024, // 1KB/s
319
		}
320
321
		// Copy content
322
		done := make(chan bool)
323
		go downloader.copyContent(src, &dst, done)
324
325
		// Wait for completion with timeout
326
		select {
327
		case <-done:
328
			// Success
329
		case <-time.After(2 * time.Second):
330
			t.Fatalf("Copy with rate limit timed out")
331
		}
332
333
		// Verify copied content
334
		if dst.String() != testData {
335
			t.Errorf("Expected content '%s', got '%s'", testData, dst.String())
336
		}
337
	})
338
}
339
340
func TestCopyContentWithSharedLimiter(t *testing.T) {
341
	// Verify copyContent path using sharedLimiter copies data correctly.
342
	testData := "shared limiter content"
343
	src := strings.NewReader(testData)
344
	var dst strings.Builder
345
346
	downloader := &HTTPDownloader{
347
		sharedLimiter: rate.NewLimiter(rate.Limit(1<<20), 1<<20), // high limit to avoid slowness
348
	}
349
350
	done := make(chan bool)
351
	go downloader.copyContent(src, &dst, done)
352
	<-done
353
354
	if dst.String() != testData {
355
		t.Errorf("Expected content '%s', got '%s'", testData, dst.String())
356
	}
357
}
358
359
func TestProxyAwareHTTPClientConfiguration(t *testing.T) {
360
	// HTTP proxy config should set Transport.Proxy
361
	client := ProxyAwareHTTPClient("http://127.0.0.1:3128", false)
362
	tr, ok := client.Transport.(*http.Transport)
363
	if !ok {
364
		t.Fatalf("Transport is not *http.Transport")
365
	}
366
	if tr.Proxy == nil {
367
		t.Errorf("expected HTTP proxy to be configured on Transport.Proxy")
368
	}
369
370
	// SOCKS5 proxy config should set DialContext
371
	client = ProxyAwareHTTPClient("127.0.0.1:1080", false)
372
	tr, ok = client.Transport.(*http.Transport)
373
	if !ok || tr.DialContext == nil {
374
		t.Errorf("expected SOCKS5 DialContext to be configured")
375
	}
376
377
	// Ensure timeouts are set
378
	if tr.TLSHandshakeTimeout == 0 || tr.ResponseHeaderTimeout == 0 || tr.ExpectContinueTimeout == 0 {
379
		t.Errorf("expected transport timeouts to be set")
380
	}
381
}
382
383
func TestNewHTTPDownloaderProbe(t *testing.T) {
384
	// HEAD returns Accept-Ranges and Content-Length
385
	content := strings.Repeat("x", 1024)
386
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
387
		switch r.Method {
388
		case http.MethodHead:
389
			w.Header().Set("Accept-Ranges", "bytes")
390
			w.Header().Set("Content-Length", "1024")
391
			w.WriteHeader(http.StatusOK)
392
		case http.MethodGet:
393
			w.WriteHeader(http.StatusOK)
394
			_, _ = w.Write([]byte(content))
395
		}
396
	})
397
	ts := httptest.NewServer(h)
398
	defer ts.Close()
399
400
	d := NewHTTPDownloader(ts.URL, 4, false, "", "")
401
	if d.par != 4 {
402
		t.Fatalf("expected par=4, got %d", d.par)
403
	}
404
	if d.len != 1024 {
405
		t.Fatalf("expected len=1024, got %d", d.len)
406
	}
407
	if !d.resumable {
408
		t.Fatalf("expected resumable=true")
409
	}
410
	if len(d.parts) != 4 {
411
		t.Fatalf("expected 4 parts, got %d", len(d.parts))
412
	}
413
}
414
415
func TestNewHTTPDownloaderRangeFallback(t *testing.T) {
416
	// HEAD has no Accept-Ranges/Content-Length. GET with Range returns 206 + Content-Range
417
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
418
		switch r.Method {
419
		case http.MethodHead:
420
			w.WriteHeader(http.StatusOK)
421
		case http.MethodGet:
422
			if rng := r.Header.Get("Range"); strings.HasPrefix(rng, "bytes=") {
423
				w.Header().Set("Content-Range", "bytes 0-0/4096")
424
				w.WriteHeader(http.StatusPartialContent)
425
				_, _ = w.Write([]byte("x"))
426
				return
427
			}
428
			w.WriteHeader(http.StatusOK)
429
			_, _ = w.Write([]byte(strings.Repeat("x", 4096)))
430
		}
431
	})
432
	ts := httptest.NewServer(h)
433
	defer ts.Close()
434
435
	d := NewHTTPDownloader(ts.URL, 4, false, "", "")
436
	if d.par != 4 {
437
		t.Fatalf("expected par=4, got %d", d.par)
438
	}
439
	if d.len != 4096 {
440
		t.Fatalf("expected len=4096, got %d", d.len)
441
	}
442
}
443