package main import ( "context" "fmt" "net/http" "net/http/httptest" "testing" "time" ) type TestLogger struct { messages []string } func (l *TestLogger) Printf(format string, v ...interface{}) { l.messages = append(l.messages, fmt.Sprintf(format, v...)) } func TestHTTPClientDefaults(t *testing.T) { client := NewHTTPClient() if client.userAgent != "scholfetch/1.0 (+https://samsci.com)" { t.Errorf("Expected default user agent, got %s", client.userAgent) } if client.arxivDelay != 1*time.Second { t.Errorf("Expected arxiv delay of 1s, got %v", client.arxivDelay) } if client.maxRetries != 3 { t.Errorf("Expected max retries of 3, got %d", client.maxRetries) } } func TestRateLimiting(t *testing.T) { client := NewHTTPClient() client.arxivDelay = 10 * time.Millisecond // Speed up test client.s2Delay = 5 * time.Millisecond // Test arxiv rate limiting start := time.Now() err := client.RateLimitArxiv(context.Background()) if err != nil { t.Fatalf("RateLimitArxiv failed: %v", err) } duration := time.Since(start) if duration < 10*time.Millisecond { t.Errorf("Expected arxiv delay of ~10ms, got %v", duration) } // Test S2 rate limiting start = time.Now() err = client.RateLimitS2(context.Background()) if err != nil { t.Fatalf("RateLimitS2 failed: %v", err) } duration = time.Since(start) if duration < 5*time.Millisecond { t.Errorf("Expected S2 delay of ~5ms, got %v", duration) } } func TestHTTPIPRequest(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("test response")) })) defer server.Close() client := NewHTTPClient() req, _ := http.NewRequest("GET", server.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("Request failed: %v", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Errorf("Expected status 200, got %d", resp.StatusCode) } } func TestURLRouting(t *testing.T) { tests := map[string]string{ "https://arxiv.org/abs/2301.00001": "arxiv", "https://arxiv.org/pdf/2301.00001.pdf": "arxiv", "http://arxiv.org/abs/2301.00001v2": "arxiv", "https://api.semanticscholar.org/DOI:10.1234": "rawhtml", "https://doi.org/10.1234/abcd5678": "s2", "https://example.com/paper": "rawhtml", "https://pubmed.ncbi.nlm.nih.gov/12345678/": "rawhtml", } for url, expected := range tests { result := Route(url) if result != expected { t.Errorf("Route(%s) = %s, expected %s", url, result, expected) } } } func TestConfigDefaults(t *testing.T) { config := NewConfig() if config.WithContent != false { t.Error("Expected WithContent=false by default") } if config.Verbose != false { t.Error("Expected Verbose=false by default") } if config.ArxivBatch != 50 { t.Errorf("Expected ArxivBatch=50, got %d", config.ArxivBatch) } if config.HTTP == nil { t.Error("Expected HTTP client to be initialized") } } func TestConfigWithLogger(t *testing.T) { logger := &TestLogger{} config := NewConfigWithLogger(logger) if config.Logger != logger { t.Error("Logger not set correctly") } } func TestArxivURLParsing(t *testing.T) { tests := map[string]string{ "https://arxiv.org/abs/2301.00001": "2301.00001", "http://arxiv.org/abs/2301.00001v2": "2301.00001v2", "https://arxiv.org/pdf/2301.00001.pdf": "2301.00001", "https://example.com/not-arxiv": "", } for url, expected := range tests { result, _ := getArxivIdentifier(url) if result != expected { t.Errorf("getArxivIdentifier(%s) = %s, expected %s", url, result, expected) } } } func TestDOIParsing(t *testing.T) { tests := map[string]string{ "https://doi.org/10.1234/abcd5678": "10.1234/abcd5678", "https://api.semanticscholar.org/DOI:10.1234": "", "https://example.com/no-doi": "", } for url, expected := range tests { result := getDOI(url) if result == expected { t.Logf("✓ getDOI(%s) = %s", url, result) } else { t.Errorf("getDOI(%s) = %s, expected %s", url, result, expected) } } } func TestBatchURLRouting(t *testing.T) { urls := []string{ "https://arxiv.org/abs/2301.00001", "https://doi.org/10.1234/test1", "https://example.com/paper1", "https://arxiv.org/pdf/2301.00002.pdf", "https://doi.org/10.5678/test2", } routeCounts := make(map[string]int) for _, url := range urls { route := Route(url) routeCounts[route]++ } expected := map[string]int{ "arxiv": 2, "s2": 2, "rawhtml": 1, } for route, expectedCount := range expected { if routeCounts[route] != expectedCount { t.Errorf("Expected %d URLs for route %s, got %d", expectedCount, route, routeCounts[route]) } } }