package main import ( "context" _ "embed" "encoding/json" "errors" "flag" "fmt" "io" "net" "net/http" "os" "os/signal" "strconv" "strings" "sync" "syscall" "time" "html/template" ) //go:embed templates/index.html var indexHTML string type ServeCommand struct { Port int DBPath string RateLimit int } func (c *ServeCommand) Name() string { return "serve" } func (c *ServeCommand) Init(args []string) error { fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError) fs.IntVar(&c.Port, "port", 8080, "Port to listen on") fs.StringVar(&c.DBPath, "db", "data/abvjt.db", "Path to SQLite database") fs.IntVar(&c.RateLimit, "rate-limit", 30, "Max requests per minute per IP") if err := fs.Parse(args); err != nil { if errors.Is(err, flag.ErrHelp) { fs.Usage() return nil } return err } return nil } func (c *ServeCommand) Run(stdin io.Reader, stdout io.Writer) error { db, err := OpenDB(c.DBPath) if err != nil { return fmt.Errorf("failed to open database: %w", err) } defer db.Close() http.HandleFunc("/", handleRoot()) http.HandleFunc("/api/search", handleSearch(db, c.RateLimit)) http.HandleFunc("/api/health", handleHealth(db, c.RateLimit)) addr := fmt.Sprintf(":%d", c.Port) server := &http.Server{ Addr: addr, ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, } listener, err := net.Listen("tcp", addr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", addr, err) } fmt.Fprintf(stdout, "Server listening on %s (DB: %s, rate limit: %d/min)\n", addr, c.DBPath, c.RateLimit) // Graceful shutdown idleConnsClosed := make(chan struct{}) go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) <-sigint fmt.Fprintln(stdout, "\nShutting down...") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { fmt.Fprintf(os.Stderr, "Shutdown error: %v\n", err) } close(idleConnsClosed) }() go func() { if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { fmt.Fprintf(os.Stderr, "Server error: %v\n", err) os.Exit(1) } }() <-idleConnsClosed return nil } func handleRoot() http.HandlerFunc { tmpl, err := template.New("index").Parse(indexHTML) if err != nil { panic(fmt.Sprintf("failed to parse template: %v", err)) } return func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { http.NotFound(w, r) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl.Execute(w, nil) } } func handleSearch(db *DB, rateLimit int) http.HandlerFunc { rl := newRateLimiter(rateLimit) return func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } ip := clientIP(r) if !rl.allow(ip) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{"error": "rate limit exceeded"}) return } query := strings.TrimSpace(r.URL.Query().Get("q")) if query == "" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(map[string]string{"error": "missing query parameter 'q'"}) return } if len(query) > 200 { query = query[:200] } limitStr := r.URL.Query().Get("limit") limit, err := strconv.Atoi(limitStr) if err != nil || limit <= 0 { limit = 50 } results, err := db.SearchJournals(query, limit) if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(map[string]string{"error": "search failed"}) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(results) } } func handleHealth(db *DB, rateLimit int) http.HandlerFunc { rl := newRateLimiter(rateLimit) return func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } ip := clientIP(r) if !rl.allow(ip) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{"error": "rate limit exceeded"}) return } count, err := db.Count() status := "ok" if err != nil { status = "error" } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "status": status, "db_loaded": err == nil, "total_journals": count, }) } } func clientIP(r *http.Request) string { ip := r.Header.Get("X-Forwarded-For") if ip != "" { parts := strings.Split(ip, ",") return strings.TrimSpace(parts[0]) } ip = r.Header.Get("X-Real-Ip") if ip != "" { return strings.TrimSpace(ip) } host, _, _ := net.SplitHostPort(r.RemoteAddr) return host } type rateLimiter struct { mu sync.RWMutex buckets map[string]*bucket maxPerMin int } type bucket struct { tokens float64 lastCheck time.Time } func newRateLimiter(maxPerMin int) *rateLimiter { return &rateLimiter{ buckets: make(map[string]*bucket), maxPerMin: maxPerMin, } } func (rl *rateLimiter) allow(ip string) bool { now := time.Now() rl.mu.Lock() defer rl.mu.Unlock() b, ok := rl.buckets[ip] if !ok { rl.buckets[ip] = &bucket{ tokens: float64(rl.maxPerMin) - 1, lastCheck: now, } return true } elapsed := now.Sub(b.lastCheck).Minutes() b.tokens = min(float64(rl.maxPerMin), b.tokens+elapsed*float64(rl.maxPerMin)) b.lastCheck = now if b.tokens >= 1 { b.tokens-- return true } return false }