diff options
| author | Sam Scholten | 2026-06-14 20:00:15 +1000 |
|---|---|---|
| committer | Sam Scholten | 2026-06-14 20:00:15 +1000 |
| commit | decc46c876e7b5552f5f5ecac4ee4f1a64ad1d62 (patch) | |
| tree | 46875e236a062189115c0cd8ed8f1d82980c16b7 /server.go | |
| download | abvjt-main.tar.gz abvjt-main.zip | |
Diffstat (limited to 'server.go')
| -rw-r--r-- | server.go | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/server.go b/server.go new file mode 100644 index 0000000..eeaa42e --- /dev/null +++ b/server.go @@ -0,0 +1,251 @@ +package main + +import ( + "context" + _ "embed" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net" + "net/http" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "html/template" +) + +//go:embed templates/index.html +var indexHTML string + +type ServeCommand struct { + Port int + DBPath string + RateLimit int +} + +func (c *ServeCommand) Name() string { return "serve" } + +func (c *ServeCommand) Init(args []string) error { + fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + fs.IntVar(&c.Port, "port", 8080, "Port to listen on") + fs.StringVar(&c.DBPath, "db", "data/abvjt.db", "Path to SQLite database") + fs.IntVar(&c.RateLimit, "rate-limit", 30, "Max requests per minute per IP") + if err := fs.Parse(args); err != nil { + if errors.Is(err, flag.ErrHelp) { + fs.Usage() + return nil + } + return err + } + return nil +} + +func (c *ServeCommand) Run(stdin io.Reader, stdout io.Writer) error { + db, err := OpenDB(c.DBPath) + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + defer db.Close() + + http.HandleFunc("/", handleRoot()) + http.HandleFunc("/api/search", handleSearch(db, c.RateLimit)) + http.HandleFunc("/api/health", handleHealth(db, c.RateLimit)) + + addr := fmt.Sprintf(":%d", c.Port) + server := &http.Server{ + Addr: addr, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + } + + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + fmt.Fprintf(stdout, "Server listening on %s (DB: %s, rate limit: %d/min)\n", addr, c.DBPath, c.RateLimit) + + // Graceful shutdown + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + <-sigint + fmt.Fprintln(stdout, "\nShutting down...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Shutdown error: %v\n", err) + } + close(idleConnsClosed) + }() + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + fmt.Fprintf(os.Stderr, "Server error: %v\n", err) + os.Exit(1) + } + }() + + <-idleConnsClosed + return nil +} + +func handleRoot() http.HandlerFunc { + tmpl, err := template.New("index").Parse(indexHTML) + if err != nil { + panic(fmt.Sprintf("failed to parse template: %v", err)) + } + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + tmpl.Execute(w, nil) + } +} + +func handleSearch(db *DB, rateLimit int) http.HandlerFunc { + rl := newRateLimiter(rateLimit) + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ip := clientIP(r) + if !rl.allow(ip) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{"error": "rate limit exceeded"}) + return + } + + query := strings.TrimSpace(r.URL.Query().Get("q")) + if query == "" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{"error": "missing query parameter 'q'"}) + return + } + + if len(query) > 200 { + query = query[:200] + } + + limitStr := r.URL.Query().Get("limit") + limit, err := strconv.Atoi(limitStr) + if err != nil || limit <= 0 { + limit = 50 + } + + results, err := db.SearchJournals(query, limit) + if err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]string{"error": "search failed"}) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(results) + } +} + +func handleHealth(db *DB, rateLimit int) http.HandlerFunc { + rl := newRateLimiter(rateLimit) + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ip := clientIP(r) + if !rl.allow(ip) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{"error": "rate limit exceeded"}) + return + } + + count, err := db.Count() + status := "ok" + if err != nil { + status = "error" + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": status, + "db_loaded": err == nil, + "total_journals": count, + }) + } +} + +func clientIP(r *http.Request) string { + ip := r.Header.Get("X-Forwarded-For") + if ip != "" { + parts := strings.Split(ip, ",") + return strings.TrimSpace(parts[0]) + } + ip = r.Header.Get("X-Real-Ip") + if ip != "" { + return strings.TrimSpace(ip) + } + host, _, _ := net.SplitHostPort(r.RemoteAddr) + return host +} + +type rateLimiter struct { + mu sync.RWMutex + buckets map[string]*bucket + maxPerMin int +} + +type bucket struct { + tokens float64 + lastCheck time.Time +} + +func newRateLimiter(maxPerMin int) *rateLimiter { + return &rateLimiter{ + buckets: make(map[string]*bucket), + maxPerMin: maxPerMin, + } +} + +func (rl *rateLimiter) allow(ip string) bool { + now := time.Now() + + rl.mu.Lock() + defer rl.mu.Unlock() + + b, ok := rl.buckets[ip] + if !ok { + rl.buckets[ip] = &bucket{ + tokens: float64(rl.maxPerMin) - 1, + lastCheck: now, + } + return true + } + + elapsed := now.Sub(b.lastCheck).Minutes() + b.tokens = min(float64(rl.maxPerMin), b.tokens+elapsed*float64(rl.maxPerMin)) + b.lastCheck = now + + if b.tokens >= 1 { + b.tokens-- + return true + } + return false +} |
