author    Sam Scholten  2025-12-15 19:34:17 +1000
committer Sam Scholten  2025-12-15 19:34:59 +1000
commit    9f5978186ac3de07f4325975fecf4f538fe713b6 (patch)
tree      41440b703054fe59eb561ba81d80fd60380c1f7a /cmds/train.go
Init v0.1.0
Diffstat (limited to 'cmds/train.go')
-rw-r--r--  cmds/train.go  841
1 file changed, 841 insertions(+), 0 deletions(-)
diff --git a/cmds/train.go b/cmds/train.go
new file mode 100644
index 0000000..e7e8915
--- /dev/null
+++ b/cmds/train.go
@@ -0,0 +1,841 @@
+// Train command learns model from positive examples and RSS feeds.
+// Loads positives, fetches RSS feeds as negatives, excludes overlap,
+// trains TF-IDF + logistic regression with 1:1 class balancing.
+// Outputs model with validation threshold to stdout.
+package cmds
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "math"
+ "math/rand"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/mmcdole/gofeed"
+ "scholscan/core"
+)
+
+// ============================================================================
+// ┏━╸┏┳┓╺┳┓ ┏━┓┏┓ ┏┓
+// ┃ ┃┃┃ ┃┃ ┃ ┃┣┻┓ ┃
+// ┗━╸╹ ╹╺┻┛ ┗━┛┗━┛┗━┛
+// ============================================================================
+
+// TrainCommand learns a model from positive examples and RSS feeds, writing
+// the trained model as JSON to stdout.
+type TrainCommand struct {
+ positivesFile string
+ rssFeedsFile string
+ verboseOutput bool
+ lambda float64
+ minDF int
+ maxDF float64
+ ngramMax int
+}
+
+func (c *TrainCommand) Name() string { return "train" }
+
+func (c *TrainCommand) Init(args []string) error {
+ fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
+ fs.Usage = func() {
+ fmt.Fprint(fs.Output(), `Usage: scholscan train POSITIVES_FILE --rss-feeds RSS_FEEDS_FILE > model.json
+
+Train a TF-IDF + logistic regression model from positive examples and RSS feeds.
+
+The training workflow:
+ 1. Load positive examples from POSITIVES_FILE
+ 2. Fetch articles from RSS feeds list
+ 3. Exclude any positive examples from RSS feed articles
+ 4. Train model with balanced classes
+ 5. Output trained model to stdout as JSON
+
+Flags:
+`)
+ fs.PrintDefaults()
+ fmt.Fprint(fs.Output(), `
+Arguments:
+ POSITIVES_FILE Path to JSONL file with positive examples (required)
+
+Example:
+ scholscan train positives.jsonl --rss-feeds rss_world.txt > model.json
+`)
+ }
+
+ fs.StringVar(&c.rssFeedsFile, "rss-feeds", "", "Path to text file with RSS feed URLs (required)")
+ fs.BoolVar(&c.verboseOutput, "verbose", false, "Show progress information")
+ fs.Float64Var(&c.lambda, "lambda", 0.001, "L2 regularization parameter for logistic regression")
+ fs.IntVar(&c.minDF, "min-df", 2, "Minimum document frequency (absolute count)")
+ fs.Float64Var(&c.maxDF, "max-df", 0.8, "Maximum document frequency (ratio, 0-1)")
+ fs.IntVar(&c.ngramMax, "ngram-max", 2, "Maximum n-gram size (e.g., 1=unigrams, 2=unigrams+bigrams)")
+
+ // Check for help flag first
+ for _, arg := range args {
+ if arg == "--help" || arg == "-h" {
+ fs.Usage()
+ return flag.ErrHelp
+ }
+ }
+
+ // Extract positional argument (POSITIVES_FILE) before parsing flags
+ if len(args) == 0 {
+ return fmt.Errorf("POSITIVES_FILE argument is required")
+ }
+ // The first argument should be the positives file, the rest are flags
+ c.positivesFile = args[0]
+ flagArgs := args[1:]
+
+ if err := fs.Parse(flagArgs); err != nil {
+ return err
+ }
+
+ if c.rssFeedsFile == "" {
+ return fmt.Errorf("--rss-feeds flag is required")
+ }
+
+ // Validate paths are safe (prevent directory traversal)
+ if strings.Contains(filepath.Clean(c.positivesFile), "..") {
+ return fmt.Errorf("invalid positives file path: directory traversal not allowed")
+ }
+ if strings.Contains(filepath.Clean(c.rssFeedsFile), "..") {
+ return fmt.Errorf("invalid RSS feeds file path: directory traversal not allowed")
+ }
+
+ return nil
+}
+
+func (c *TrainCommand) Run(stdin io.Reader, stdout io.Writer) error {
+ if c.verboseOutput {
+ log.SetOutput(os.Stderr)
+ log.Println("Starting training workflow...")
+ log.Printf("Positives: %s", c.positivesFile)
+ log.Printf("RSS feeds: %s", c.rssFeedsFile)
+ }
+
+ if c.verboseOutput {
+ log.Printf("Loading positives from %s...", c.positivesFile)
+ }
+ positives, err := c.loadArticles(c.positivesFile)
+ if err != nil {
+ return fmt.Errorf("failed to load positives: %w", err)
+ }
+ if c.verboseOutput {
+ log.Printf("Loaded %d positive examples", len(positives))
+ }
+
+ if c.verboseOutput {
+ log.Printf("Loading RSS feeds from %s...", c.rssFeedsFile)
+ }
+ rssURLs, err := c.loadRSSURLs(c.rssFeedsFile)
+ if err != nil {
+ return fmt.Errorf("failed to load RSS feeds: %w", err)
+ }
+ if c.verboseOutput {
+ log.Printf("Found %d RSS feeds to fetch", len(rssURLs))
+ }
+
+ negatives, err := c.fetchFromRSSFeeds(rssURLs)
+ if err != nil {
+ return fmt.Errorf("failed to fetch from RSS feeds: %w", err)
+ }
+ if c.verboseOutput {
+ log.Printf("Fetched %d articles from RSS feeds", len(negatives))
+ }
+
+ negatives = c.excludePositives(negatives, positives)
+ if c.verboseOutput {
+ log.Printf("After exclusion: %d negative examples", len(negatives))
+ }
+
+ if len(positives) == 0 || len(negatives) == 0 {
+ return fmt.Errorf("need both positive (%d) and negative (%d) examples for training", len(positives), len(negatives))
+ }
+
+ if c.verboseOutput {
+ log.Println("Training model...")
+ }
+ model, err := c.trainModel(positives, negatives)
+ if err != nil {
+ return fmt.Errorf("failed to train model: %w", err)
+ }
+
+ // Output model
+ encoder := json.NewEncoder(stdout)
+ encoder.SetIndent("", " ")
+ if err := encoder.Encode(model); err != nil {
+ return fmt.Errorf("failed to write model: %w", err)
+ }
+
+ return nil
+}
+
+// ============================================================================
+// ╺┳┓┏━┓╺┳╸┏━┓ ╻ ┏━┓┏━┓╺┳┓╻┏┓╻┏━╸
+// ┃┃┣━┫ ┃ ┣━┫ ┃ ┃ ┃┣━┫ ┃┃┃┃┗┫┃╺┓
+// ╺┻┛╹ ╹ ╹ ╹ ╹ ┗━╸┗━┛╹ ╹╺┻┛╹╹ ╹┗━┛
+// ============================================================================
+
+// loadArticles loads articles from a JSONL file, one JSON object per line.
+func (c *TrainCommand) loadArticles(filename string) ([]*core.Article, error) {
+ file, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ var articles []*core.Article
+ scanner := bufio.NewScanner(file)
+ // Allow long lines: article content can exceed bufio's default 64KB limit.
+ scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+ lineCount := 0
+ for scanner.Scan() {
+ lineCount++
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+ // Skip malformed JSON lines instead of failing on bad input. Scanning
+ // line by line keeps records independent; a shared json.Decoder would
+ // be left in an unrecoverable state after its first syntax error.
+ var article core.Article
+ if err := json.Unmarshal([]byte(line), &article); err != nil {
+ continue
+ }
+ articles = append(articles, &article)
+ if c.verboseOutput && lineCount%500 == 0 {
+ log.Printf(" Loaded %d articles so far", len(articles))
+ }
+ }
+ return articles, scanner.Err()
+}
+
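+// Illustrative sketch of the expected JSONL input (one object per line); the
+// exact keys depend on core.Article's JSON tags, so these field names are
+// assumptions:
+//
+//	{"url": "https://example.com/a", "title": "Some headline", "content": "..."}
+//	{"url": "https://example.com/b", "title": "Another headline", "content": "..."}
+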
+// loadRSSURLs loads RSS feed URLs from a text file
+func (c *TrainCommand) loadRSSURLs(filename string) ([]string, error) {
+ file, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ var urls []string
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line != "" && !strings.HasPrefix(line, "#") {
+ urls = append(urls, line)
+ }
+ }
+ return urls, scanner.Err()
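+
+// The feeds file is plain text with one URL per line; blank lines and lines
+// beginning with "#" are ignored. For example:
+//
+//	# world news
+//	https://example.com/world/rss.xml
+//	file:///tmp/local_feed.xml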
+}
+
+// fetchFromRSSFeeds fetches articles from multiple RSS feeds in parallel
+func (c *TrainCommand) fetchFromRSSFeeds(rssURLs []string) ([]*core.Article, error) {
+ client := core.DefaultHTTPClient
+ type result struct {
+ url string
+ articles []*core.Article
+ err error
+ }
+ resultChan := make(chan result, len(rssURLs))
+
+ for _, rssURL := range rssURLs {
+ go func(feedURL string) {
+ articles, err := c.fetchRSSFeed(client, feedURL)
+ resultChan <- result{url: feedURL, articles: articles, err: err}
+ }(rssURL)
+ }
+
+ var allArticles []*core.Article
+ for i := 0; i < len(rssURLs); i++ {
+ res := <-resultChan
+ if res.err != nil {
+ if c.verboseOutput {
+ log.Printf("%s: failed to fetch", shortURL(res.url))
+ }
+ } else {
+ if c.verboseOutput {
+ log.Printf("%s: %d articles", shortURL(res.url), len(res.articles))
+ }
+ allArticles = append(allArticles, res.articles...)
+ }
+ }
+
+ return allArticles, nil
+}
+
+// ParseRSSFeed parses an RSS/Atom feed from the provided body into a slice of Articles.
+func ParseRSSFeed(body []byte, baseURL string) ([]*core.Article, error) {
+ fp := gofeed.NewParser()
+ feed, err := fp.Parse(bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+
+ var articles []*core.Article
+ for _, item := range feed.Items {
+ // Prefer explicit content; fall back to description.
+ content := strings.TrimSpace(item.Content)
+ if content == "" {
+ content = item.Description
+ }
+ // Also check custom content field (for <content> tags in RSS)
+ if content == "" && item.Custom != nil {
+ if c, ok := item.Custom["content"]; ok && c != "" {
+ content = c
+ }
+ }
+
+ // Clean and limit content length
+ content = core.CleanFeedContent(content)
+
+ articles = append(articles, &core.Article{
+ URL: item.Link,
+ Title: item.Title,
+ Content: content,
+ })
+ }
+ return articles, nil
+}
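+
+// Illustrative example: a minimal RSS 2.0 item such as
+//
+//	<item>
+//	  <title>Example headline</title>
+//	  <link>https://example.com/a</link>
+//	  <description>Summary text.</description>
+//	</item>
+//
+// becomes an Article with Title "Example headline", URL "https://example.com/a",
+// and Content taken from the description (via CleanFeedContent), since no
+// <content:encoded> element is present.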
+
+// fetchRSSFeed fetches and parses a single RSS feed
+func (c *TrainCommand) fetchRSSFeed(client *http.Client, rssURL string) ([]*core.Article, error) {
+ var body []byte
+ var err error
+
+ // Handle file:// URLs locally
+ if strings.HasPrefix(rssURL, "file://") {
+ // Remove file:// prefix
+ filePath := strings.TrimPrefix(rssURL, "file://")
+ body, err = os.ReadFile(filePath)
+ if err != nil {
+ return nil, fmt.Errorf("error reading file %s: %w", filePath, err)
+ }
+ } else {
+ // Handle HTTP/HTTPS URLs normally
+ req, err := http.NewRequest("GET", rssURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("error building request: %w", err)
+ }
+ req.Header.Set("User-Agent", core.PoliteUserAgent)
+
+ // Make request with retry logic
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ resp, err := core.DoRequestWithRetry(ctx, client, req)
+ if err != nil {
+ return nil, fmt.Errorf("error fetching %s: %w", rssURL, err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, rssURL)
+ }
+
+ // Read response body
+ body, err = io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("error reading response from %s: %w", rssURL, err)
+ }
+ }
+
+ // Parse RSS/Atom feed
+ return ParseRSSFeed(body, rssURL)
+}
+
+// ============================================================================
+// ╺┳┓┏━┓╺┳╸┏━┓ ┏━┓┏━┓┏━╸┏━┓
+// ┃┃┣━┫ ┃ ┣━┫ ┣━┛┣┳┛┣╸ ┣━┛
+// ╺┻┛╹ ╹ ╹ ╹ ╹ ╹ ╹┗╸┗━╸╹
+// ============================================================================
+
+func (c *TrainCommand) excludePositives(negatives, positives []*core.Article) []*core.Article {
+ // Build set of positive URLs for O(1) lookup
+ positiveURLs := make(map[string]bool)
+ for _, pos := range positives {
+ positiveURLs[pos.URL] = true
+ }
+
+ // Filter out positives
+ var filtered []*core.Article
+ for _, neg := range negatives {
+ if !positiveURLs[neg.URL] {
+ filtered = append(filtered, neg)
+ }
+ }
+
+ return filtered
+}
+
+// splitTrainingData performs a deterministic 80/20 train/validation split
+// (seed=42). Determinism keeps model training reproducible across runs.
+func (c *TrainCommand) splitTrainingData(documents []string, labels []float64) (
+ trainDocs, valDocs []string,
+ trainLabels, valLabels []float64,
+) {
+ const validationSplitRatio = 0.2
+ const splitSeed = 42
+
+ if len(documents) < 3 {
+ // Not enough data to split, use all for training.
+ // A split requires at least 2 training documents to avoid MaxDF issues
+ // and at least 1 validation document.
+ return documents, nil, labels, nil
+ }
+
+ // Create a reproducible random source and shuffle indices.
+ rng := rand.New(rand.NewSource(splitSeed))
+ indices := make([]int, len(documents))
+ for i := range indices {
+ indices[i] = i
+ }
+ rng.Shuffle(len(indices), func(i, j int) {
+ indices[i], indices[j] = indices[j], indices[i]
+ })
+
+ splitIndex := int(float64(len(documents)) * (1.0 - validationSplitRatio))
+ trainIndices := indices[:splitIndex]
+ valIndices := indices[splitIndex:]
+
+ trainDocs = make([]string, len(trainIndices))
+ trainLabels = make([]float64, len(trainIndices))
+ for i, idx := range trainIndices {
+ trainDocs[i] = documents[idx]
+ trainLabels[i] = labels[idx]
+ }
+
+ valDocs = make([]string, len(valIndices))
+ valLabels = make([]float64, len(valIndices))
+ for i, idx := range valIndices {
+ valDocs[i] = documents[idx]
+ valLabels[i] = labels[idx]
+ }
+
+ return trainDocs, valDocs, trainLabels, valLabels
+}
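+
+// For example, with 10 documents splitIndex = int(10 * 0.8) = 8, so 8 documents
+// go to training and 2 to validation; because the shuffle is seeded with 42,
+// the same inputs always yield the same partition.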
+
+// downsampleToBalance downsamples the majority class to a 1:1 ratio. It runs
+// AFTER vectorizer.Fit() so that document frequencies (IDF) reflect the full
+// training corpus rather than the balanced subset.
+func (c *TrainCommand) downsampleToBalance(docs []string, labels []float64) ([]string, []float64) {
+ // Count positives and negatives
+ var posDocs, negDocs []string
+ var posLabels, negLabels []float64
+
+ for i, label := range labels {
+ if label == 1.0 {
+ posDocs = append(posDocs, docs[i])
+ posLabels = append(posLabels, label)
+ } else {
+ negDocs = append(negDocs, docs[i])
+ negLabels = append(negLabels, label)
+ }
+ }
+
+ // If already balanced, return as-is
+ if len(posDocs) == len(negDocs) {
+ return docs, labels
+ }
+
+ // Determine which class is majority
+ var majorityDocs, minorityDocs []string
+ var majorityLabels, minorityLabels []float64
+
+ if len(negDocs) > len(posDocs) {
+ // Negatives are majority
+ majorityDocs, minorityDocs = negDocs, posDocs
+ majorityLabels, minorityLabels = negLabels, posLabels
+ } else {
+ // Positives are majority (unlikely but handle)
+ majorityDocs, minorityDocs = posDocs, negDocs
+ majorityLabels, minorityLabels = posLabels, negLabels
+ }
+
+ // Downsample majority to match minority size
+ minoritySize := len(minorityDocs)
+ rng := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
+
+ // Create random indices for downsampling
+ indices := make([]int, len(majorityDocs))
+ for i := range indices {
+ indices[i] = i
+ }
+ rng.Shuffle(len(indices), func(i, j int) {
+ indices[i], indices[j] = indices[j], indices[i]
+ })
+
+ // Select downsampled majority
+ downsampledDocs := make([]string, 0, minoritySize*2)
+ downsampledLabels := make([]float64, 0, minoritySize*2)
+
+ // Add all minority samples
+ downsampledDocs = append(downsampledDocs, minorityDocs...)
+ downsampledLabels = append(downsampledLabels, minorityLabels...)
+
+ // Add downsampled majority
+ for i := 0; i < minoritySize; i++ {
+ idx := indices[i]
+ downsampledDocs = append(downsampledDocs, majorityDocs[idx])
+ downsampledLabels = append(downsampledLabels, majorityLabels[idx])
+ }
+
+ return downsampledDocs, downsampledLabels
+}
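+
+// Illustrative counts: 30 positives and 300 negatives come back as the 30
+// positives plus 30 randomly chosen negatives (60 documents total). The
+// vectorizer has already been fit on the full training set, so IDF values
+// are unaffected by the downsampling.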
+
+// ============================================================================
+// ╺┳╸┏━┓┏━┓╻┏┓╻ ┏┳┓┏━┓╺┳┓┏━╸╻
+// ┃ ┣┳┛┣━┫┃┃┗┫ ┃┃┃┃ ┃ ┃┃┣╸ ┃
+// ╹ ╹┗╸╹ ╹╹╹ ╹ ╹ ╹┗━┛╺┻┛┗━╸┗━╸
+// ============================================================================
+
+// trainModel trains a TF-IDF + logistic regression model
+func (c *TrainCommand) trainModel(positives, negatives []*core.Article) (*core.ModelEnvelope, error) {
+ // Combine datasets and create labels
+ var documents []string
+ var labels []float64
+
+ // Process positives
+ for _, article := range positives {
+ // Skip articles with titles that are too short
+ if len(article.Title) < 15 {
+ continue
+ }
+ documents = append(documents, article.Title)
+ labels = append(labels, 1.0)
+ }
+
+ // Process negatives
+ for _, article := range negatives {
+ // Skip articles with titles that are too short
+ if len(article.Title) < 15 {
+ continue
+ }
+ documents = append(documents, article.Title)
+ labels = append(labels, 0.0)
+ }
+
+ // Use parameters from CLI flags (with defaults matching Julia implementation)
+ const vocabCap = 50000
+
+ // Deterministic 80/20 split for train/validation
+ trainDocs, valDocs, trainLabels, valLabels := c.splitTrainingData(documents, labels)
+
+ // Create TF-IDF vectorizer with the specified parameters
+ vectorizer := &core.TFIDFVectorizer{
+ NgramMin: 1,
+ NgramMax: c.ngramMax,
+ MinDF: c.minDF,
+ MaxDF: c.maxDF,
+ VocabCap: vocabCap,
+ Vocabulary: make(map[string]float64),
+ }
+ // Fit the vectorizer on the UNBALANCED training data (matching the Julia
+ // implementation) so IDF values reflect corpus-wide document frequencies.
+ vectorizer.Fit(trainDocs)
+
+ // Downsample negatives to 1:1 ratio AFTER fitting (match Julia approach)
+ balancedTrainDocs, balancedTrainLabels := c.downsampleToBalance(trainDocs, trainLabels)
+
+ // Transform both training and validation sets
+ trainVectors := vectorizer.Transform(balancedTrainDocs)
+ valVectors := vectorizer.Transform(valDocs)
+
+ // Use uniform class weights since we've balanced the dataset
+ classWeights := map[float64]float64{
+ 1.0: 1.0,
+ 0.0: 1.0,
+ }
+
+ // Train logistic regression with the specified lambda parameter
+ lr := &core.LogisticRegression{
+ LearningRate: 0.5,
+ Lambda: c.lambda,
+ Iterations: 500,
+ Tolerance: 0.000001,
+ }
+ lr.Validate()
+ weights, err := lr.Fit(trainVectors, balancedTrainLabels, classWeights)
+ if err != nil {
+ return nil, fmt.Errorf("failed to train logistic regression model: %w", err)
+ }
+
+ // Find the best threshold on the validation set
+ recommendedThreshold, scoreDistributions := c.findBestThreshold(valVectors, valLabels, weights)
+
+ // Count classes for metadata
+ var posCount, negCount float64
+ for _, label := range labels {
+ if label == 1.0 {
+ posCount++
+ } else {
+ negCount++
+ }
+ }
+
+ // Create model envelope
+ model := &core.ModelEnvelope{
+ Algorithm: "tfidf-go",
+ Impl: "go",
+ Version: "1",
+ CreatedAt: time.Now().UTC(),
+ Meta: map[string]any{
+ "positives": len(positives),
+ "negatives": len(negatives),
+ "class_counts": map[string]int{
+ "pos": int(posCount),
+ "neg": int(negCount),
+ },
+ "vectorizer_params": map[string]any{
+ "ngram_min": vectorizer.NgramMin,
+ "ngram_max": vectorizer.NgramMax,
+ "min_df": vectorizer.MinDF,
+ "max_df": vectorizer.MaxDF,
+ "vocab_cap": vectorizer.VocabCap,
+ },
+ "model_params": map[string]any{
+ "learning_rate": lr.LearningRate,
+ "lambda": lr.Lambda,
+ "iterations": lr.Iterations,
+ "tolerance": lr.Tolerance,
+ },
+ "recommended_threshold": recommendedThreshold,
+ "score_distributions": scoreDistributions,
+ },
+ Vectorizer: vectorizer.Vocabulary,
+ OrderedVocab: vectorizer.OrderedVocab,
+ Weights: weights,
+ }
+
+ return model, nil
+}
+
+// ============================================================================
+// ┏┳┓┏━╸╺┳╸┏━┓╻┏━╸┏━┓
+// ┃┃┃┣╸ ┃ ┣┳┛┃┃ ┗━┓
+// ╹ ╹┗━╸ ╹ ╹┗╸╹┗━╸┗━┛
+// ============================================================================
+
+// ClassificationMetrics holds the evaluation metrics
+type ClassificationMetrics struct {
+ TruePositives int
+ TrueNegatives int
+ FalsePositives int
+ FalseNegatives int
+ Accuracy float64
+ Precision float64
+ Recall float64
+ F1Score float64
+}
+
+// Calculate computes the metrics from raw counts
+func (m *ClassificationMetrics) Calculate() {
+ total := m.TruePositives + m.TrueNegatives + m.FalsePositives + m.FalseNegatives
+
+ if total > 0 {
+ m.Accuracy = float64(m.TruePositives+m.TrueNegatives) / float64(total)
+ }
+
+ if m.TruePositives+m.FalsePositives > 0 {
+ m.Precision = float64(m.TruePositives) / float64(m.TruePositives+m.FalsePositives)
+ }
+
+ if m.TruePositives+m.FalseNegatives > 0 {
+ m.Recall = float64(m.TruePositives) / float64(m.TruePositives+m.FalseNegatives)
+ }
+
+ if m.Precision+m.Recall > 0 {
+ m.F1Score = 2 * (m.Precision * m.Recall) / (m.Precision + m.Recall)
+ }
+}
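+
+// A worked example with hypothetical counts:
+//
+//	m := ClassificationMetrics{TruePositives: 8, TrueNegatives: 86, FalsePositives: 2, FalseNegatives: 4}
+//	m.Calculate()
+//	// Accuracy  = (8+86)/100 = 0.94
+//	// Precision = 8/(8+2)    = 0.80
+//	// Recall    = 8/(8+4)    ≈ 0.667
+//	// F1        = 2*0.80*0.667/(0.80+0.667) ≈ 0.727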
+
+// findBestThreshold sweeps candidate thresholds on a validation set and
+// returns the one that maximizes Youden's J, along with summary statistics
+// of the score distributions.
+func (c *TrainCommand) findBestThreshold(
+ validationVectors [][]float64,
+ validationLabels []float64,
+ weights []float64,
+) (float64, map[string]any) {
+ if len(validationVectors) == 0 {
+ return 0.5, nil // Default if no validation data
+ }
+
+ scores := make([]float64, len(validationVectors))
+ for i, vector := range validationVectors {
+ score, err := core.PredictScore(vector, weights)
+ if err != nil {
+ // This should not happen with valid data, but as a fallback:
+ return 0.5, nil
+ }
+ scores[i] = score
+ }
+
+ // Collect score distributions by label
+ var posScores, negScores []float64
+ for i, score := range scores {
+ if validationLabels[i] == 1.0 {
+ posScores = append(posScores, score)
+ } else {
+ negScores = append(negScores, score)
+ }
+ }
+
+ // Compute stats for each class
+ posStats := computeScoreStats(posScores)
+ negStats := computeScoreStats(negScores)
+
+ // Calculate Cohen's d (effect size) to measure class separation in the learned space
+ posMean := posStats["mean"]
+ negMean := negStats["mean"]
+ posStd := posStats["std"]
+ negStd := negStats["std"]
+
+ var cohensD float64
+ if posStd > 0 && negStd > 0 {
+ pooledStd := math.Sqrt((posStd*posStd + negStd*negStd) / 2)
+ cohensD = math.Abs(posMean-negMean) / pooledStd
+ }
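+ // Hypothetical illustration: posMean=0.85, negMean=0.20, and equal stds of
+ // 0.10 give pooledStd = sqrt((0.01+0.01)/2) = 0.10 and d = 0.65/0.10 = 6.5,
+ // far beyond the conventional "large effect" cutoff of 0.8.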
+
+ // Calculate separation ratio to understand how much the classes overlap on the score scale
+ totalRange := math.Max(posStats["max"], negStats["max"]) - math.Min(posStats["min"], negStats["min"])
+ overlapStart := math.Max(posStats["min"], negStats["min"])
+ overlapEnd := math.Min(posStats["max"], negStats["max"])
+ overlapRange := math.Max(0, overlapEnd-overlapStart)
+ separationRatio := 0.0
+ if totalRange > 0 {
+ separationRatio = (totalRange - overlapRange) / totalRange
+ }
+
+ // Find threshold that balances false positives and false negatives using Youden's J.
+ // This metric (Sensitivity + Specificity - 1) equally weights both false positive
+ // and false negative rates. Why not F1? F1 biases toward precision when classes
+ // are imbalanced; a validation set of 10 positives and 1000 negatives would push
+ // the threshold too high. Youden's J treats both types of error equally, which
+ // better reflects real use: missing a relevant article (false negative) is as bad
+ // as showing an irrelevant one (false positive).
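+ // Hypothetical illustration: sensitivity 0.90 with specificity 0.80 gives
+ // J = 0.90 + 0.80 - 1 = 0.70, whereas a threshold with specificity 0.99 but
+ // sensitivity 0.40 (the kind F1 can favor under heavy imbalance) scores only
+ // J = 0.39 and loses the sweep.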
+ bestYoudenJ := -1.0
+ bestThreshold := 0.5
+ var bestMetrics ClassificationMetrics
+
+ boolLabels := make([]bool, len(validationLabels))
+ for i, l := range validationLabels {
+ boolLabels[i] = l == 1.0
+ }
+
+ for i := 5; i <= 95; i++ {
+ threshold := float64(i) / 100.0
+ metrics := computeMetrics(scores, boolLabels, threshold)
+
+ sensitivity := metrics.Recall // TPR: TP / (TP + FN)
+ specificity := 0.0
+ if metrics.TrueNegatives+metrics.FalsePositives > 0 {
+ specificity = float64(metrics.TrueNegatives) / float64(metrics.TrueNegatives+metrics.FalsePositives)
+ }
+ youdenJ := sensitivity + specificity - 1.0
+
+ if youdenJ > bestYoudenJ {
+ bestYoudenJ = youdenJ
+ bestThreshold = threshold
+ bestMetrics = metrics
+ }
+ }
+
+ distributions := map[string]any{
+ "positive": posStats,
+ "negative": negStats,
+ "cohens_d": cohensD,
+ "separation_ratio": separationRatio,
+ "best_f1": bestMetrics.F1Score,
+ "best_precision": bestMetrics.Precision,
+ "best_recall": bestMetrics.Recall,
+ }
+
+ return bestThreshold, distributions
+}
+
+// computeScoreStats computes min, max, mean, and std for a slice of scores
+func computeScoreStats(scores []float64) map[string]float64 {
+ if len(scores) == 0 {
+ return map[string]float64{
+ "min": 0.0,
+ "max": 0.0,
+ "mean": 0.0,
+ "std": 0.0,
+ }
+ }
+
+ min, max := scores[0], scores[0]
+ sum := 0.0
+
+ for _, score := range scores {
+ if score < min {
+ min = score
+ }
+ if score > max {
+ max = score
+ }
+ sum += score
+ }
+
+ mean := sum / float64(len(scores))
+
+ // Calculate standard deviation
+ variance := 0.0
+ for _, score := range scores {
+ diff := score - mean
+ variance += diff * diff
+ }
+ variance /= float64(len(scores))
+ std := math.Sqrt(variance)
+
+ return map[string]float64{
+ "min": min,
+ "max": max,
+ "mean": mean,
+ "std": std,
+ }
+}
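+
+// Note this is the population standard deviation (dividing by n, not n-1).
+// For example, scores {0.2, 0.4, 0.6} yield mean 0.4 and
+// std = sqrt((0.04 + 0 + 0.04) / 3) ≈ 0.163.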
+
+// computeMetrics computes the confusion matrix and derived classification
+// metrics for a given decision threshold.
+func computeMetrics(scores []float64, labels []bool, threshold float64) ClassificationMetrics {
+ var metrics ClassificationMetrics
+ for i, score := range scores {
+ predicted := score > threshold
+ actual := labels[i]
+
+ if predicted && actual {
+ metrics.TruePositives++
+ } else if predicted && !actual {
+ metrics.FalsePositives++
+ } else if !predicted && actual {
+ metrics.FalseNegatives++
+ } else {
+ metrics.TrueNegatives++
+ }
+ }
+ metrics.Calculate()
+ return metrics
+}
+
+// ============================================================================
+// ╻ ╻┏━╸╻ ┏━┓┏━╸┏━┓┏━┓
+// ┣━┫┣╸ ┃ ┣━┛┣╸ ┣┳┛┗━┓
+// ╹ ╹┗━╸┗━╸╹ ┗━╸╹┗╸┗━┛
+// ============================================================================
+
+// shortURL formats a URL to be human-readable and not too long
+func shortURL(urlStr string) string {
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return urlStr
+ }
+
+ path := u.Path
+ if len(path) > 30 {
+ path = path[:30] + "..."
+ }
+
+ return u.Host + path
+}
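+
+// For example, "https://example.com/feeds/world.xml?page=2" becomes
+// "example.com/feeds/world.xml": the scheme and query string are dropped,
+// and paths longer than 30 characters are truncated with "...".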