aboutsummaryrefslogtreecommitdiff
path: root/scholfetch_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'scholfetch_test.go')
-rw-r--r--scholfetch_test.go193
1 files changed, 193 insertions, 0 deletions
diff --git a/scholfetch_test.go b/scholfetch_test.go
new file mode 100644
index 0000000..59adae7
--- /dev/null
+++ b/scholfetch_test.go
@@ -0,0 +1,193 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+// TestLogger is a logger stub that keeps every formatted message in memory
+// so that tests can inspect what was logged.
+type TestLogger struct {
+	messages []string // formatted messages, in the order Printf was called
+}
+
+// Printf formats its arguments with fmt.Sprintf and appends the resulting
+// string to the recorded messages.
+func (l *TestLogger) Printf(format string, v ...interface{}) {
+	entry := fmt.Sprintf(format, v...)
+	l.messages = append(l.messages, entry)
+}
+
+func TestHTTPClientDefaults(t *testing.T) {
+ client := NewHTTPClient()
+
+ if client.userAgent != "scholfetch/1.0 (+https://samsci.com)" {
+ t.Errorf("Expected default user agent, got %s", client.userAgent)
+ }
+
+ if client.arxivDelay != 1*time.Second {
+ t.Errorf("Expected arxiv delay of 1s, got %v", client.arxivDelay)
+ }
+
+ if client.maxRetries != 3 {
+ t.Errorf("Expected max retries of 3, got %d", client.maxRetries)
+ }
+}
+
+func TestRateLimiting(t *testing.T) {
+ client := NewHTTPClient()
+ client.arxivDelay = 10 * time.Millisecond // Speed up test
+ client.s2Delay = 5 * time.Millisecond
+
+ // Test arxiv rate limiting
+ start := time.Now()
+ err := client.RateLimitArxiv(context.Background())
+ if err != nil {
+ t.Fatalf("RateLimitArxiv failed: %v", err)
+ }
+ duration := time.Since(start)
+ if duration < 10*time.Millisecond {
+ t.Errorf("Expected arxiv delay of ~10ms, got %v", duration)
+ }
+
+ // Test S2 rate limiting
+ start = time.Now()
+ err = client.RateLimitS2(context.Background())
+ if err != nil {
+ t.Fatalf("RateLimitS2 failed: %v", err)
+ }
+ duration = time.Since(start)
+ if duration < 5*time.Millisecond {
+ t.Errorf("Expected S2 delay of ~5ms, got %v", duration)
+ }
+}
+
+// TestHTTPIPRequest sends a GET request through the client's Do method to a
+// local httptest server and checks that the response status is 200.
+//
+// NOTE(review): the name looks like a typo for TestHTTPRequest; it is kept
+// unchanged so existing -run filters keep matching.
+func TestHTTPIPRequest(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		// The write result is intentionally discarded.
+		_, _ = w.Write([]byte("test response"))
+	}))
+	defer server.Close()
+
+	client := NewHTTPClient()
+
+	// Fail fast if the request cannot be built instead of passing a nil
+	// request on to Do.
+	req, err := http.NewRequest(http.MethodGet, server.URL, nil)
+	if err != nil {
+		t.Fatalf("Building request failed: %v", err)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("Request failed: %v", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		t.Errorf("Expected status 200, got %d", resp.StatusCode)
+	}
+}
+
+// TestURLRouting checks that Route returns the expected route name for a
+// set of sample URLs.
+func TestURLRouting(t *testing.T) {
+	// Each input URL is mapped to the route name Route should return for it.
+	wantByURL := map[string]string{
+		"https://arxiv.org/abs/2301.00001":            "arxiv",
+		"https://arxiv.org/pdf/2301.00001.pdf":        "arxiv",
+		"http://arxiv.org/abs/2301.00001v2":           "arxiv",
+		"https://api.semanticscholar.org/DOI:10.1234": "rawhtml",
+		"https://doi.org/10.1234/abcd5678":            "s2",
+		"https://example.com/paper":                   "rawhtml",
+		"https://pubmed.ncbi.nlm.nih.gov/12345678/":   "rawhtml",
+	}
+
+	for rawURL, want := range wantByURL {
+		if got := Route(rawURL); got != want {
+			t.Errorf("Route(%s) = %s, expected %s", rawURL, got, want)
+		}
+	}
+}
+
+// TestConfigDefaults checks the values found on a Config freshly returned by
+// NewConfig: WithContent and Verbose are false, ArxivBatch is 50, and the
+// HTTP field is non-nil.
+func TestConfigDefaults(t *testing.T) {
+	cfg := NewConfig()
+
+	// Boolean fields are tested directly rather than compared to false.
+	if cfg.WithContent {
+		t.Error("Expected WithContent=false by default")
+	}
+
+	if cfg.Verbose {
+		t.Error("Expected Verbose=false by default")
+	}
+
+	if cfg.ArxivBatch != 50 {
+		t.Errorf("Expected ArxivBatch=50, got %d", cfg.ArxivBatch)
+	}
+
+	if cfg.HTTP == nil {
+		t.Error("Expected HTTP client to be initialized")
+	}
+}
+
+// TestConfigWithLogger checks that the Config returned by
+// NewConfigWithLogger holds the very logger it was given.
+func TestConfigWithLogger(t *testing.T) {
+	recorder := &TestLogger{}
+
+	if cfg := NewConfigWithLogger(recorder); cfg.Logger != recorder {
+		t.Error("Logger not set correctly")
+	}
+}
+
+// TestArxivURLParsing checks the first value returned by getArxivIdentifier
+// for several URLs, including one non-arXiv URL for which an empty string is
+// expected.
+func TestArxivURLParsing(t *testing.T) {
+	// Each input URL is mapped to the identifier expected for it.
+	wantByURL := map[string]string{
+		"https://arxiv.org/abs/2301.00001":     "2301.00001",
+		"http://arxiv.org/abs/2301.00001v2":    "2301.00001v2",
+		"https://arxiv.org/pdf/2301.00001.pdf": "2301.00001",
+		"https://example.com/not-arxiv":        "",
+	}
+
+	for rawURL, want := range wantByURL {
+		// NOTE(review): the second return value is discarded here; confirm
+		// that ignoring it is intended.
+		got, _ := getArxivIdentifier(rawURL)
+		if got == want {
+			continue
+		}
+		t.Errorf("getArxivIdentifier(%s) = %s, expected %s", rawURL, got, want)
+	}
+}
+
+// TestDOIParsing checks the value returned by getDOI for several URLs. A
+// match is logged; a mismatch is reported as a test error.
+func TestDOIParsing(t *testing.T) {
+	// Each input URL is mapped to the DOI expected for it ("" for none).
+	wantByURL := map[string]string{
+		"https://doi.org/10.1234/abcd5678":            "10.1234/abcd5678",
+		"https://api.semanticscholar.org/DOI:10.1234": "",
+		"https://example.com/no-doi":                  "",
+	}
+
+	for rawURL, want := range wantByURL {
+		got := getDOI(rawURL)
+		if got != want {
+			t.Errorf("getDOI(%s) = %s, expected %s", rawURL, got, want)
+			continue
+		}
+		t.Logf("✓ getDOI(%s) = %s", rawURL, got)
+	}
+}
+
+// TestBatchURLRouting runs Route over a mixed list of URLs, tallies how many
+// land on each route name, and compares the tallies with the expected counts.
+func TestBatchURLRouting(t *testing.T) {
+	inputs := []string{
+		"https://arxiv.org/abs/2301.00001",
+		"https://doi.org/10.1234/test1",
+		"https://example.com/paper1",
+		"https://arxiv.org/pdf/2301.00002.pdf",
+		"https://doi.org/10.5678/test2",
+	}
+
+	// Expected number of inputs per route name.
+	want := map[string]int{
+		"arxiv":   2,
+		"s2":      2,
+		"rawhtml": 1,
+	}
+
+	got := make(map[string]int)
+	for _, u := range inputs {
+		got[Route(u)]++
+	}
+
+	for name, n := range want {
+		if got[name] == n {
+			continue
+		}
+		t.Errorf("Expected %d URLs for route %s, got %d",
+			n, name, got[name])
+	}
+}
\ No newline at end of file