Diffstat (limited to 'cmds/train.go')
-rw-r--r--  cmds/train.go  841
1 file changed, 841 insertions, 0 deletions
diff --git a/cmds/train.go b/cmds/train.go
new file mode 100644
index 0000000..e7e8915
--- /dev/null
+++ b/cmds/train.go
@@ -0,0 +1,841 @@
+// The train command learns a model from positive examples and RSS feeds.
+// It loads positives, fetches RSS feed articles as negatives, excludes the
+// overlap, and trains TF-IDF + logistic regression with 1:1 class balancing.
+// It writes the model, with a validated decision threshold, to stdout.
+package cmds
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"math"
+	"math/rand"
+	"net/http"
+	"net/url"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/mmcdole/gofeed"
+	"scholscan/core"
+)
+
+// ============================================================================
+// ┏━╸┏┳┓╺┳┓ ┏━┓┏┓ ┏┓
+// ┃  ┃┃┃ ┃┃ ┃ ┃┣┻┓ ┃
+// ┗━╸╹ ╹╺┻┛ ┗━┛┗━┛┗━┛
+// ============================================================================
+
+// TrainCommand learns a model from positive examples and RSS feeds and
+// writes the trained model JSON to stdout.
+type TrainCommand struct {
+	positivesFile string
+	rssFeedsFile  string
+	verboseOutput bool
+	lambda        float64
+	minDF         int
+	maxDF         float64
+	ngramMax      int
+}
+
+func (c *TrainCommand) Name() string { return "train" }
+
+func (c *TrainCommand) Init(args []string) error {
+	fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
+	fs.Usage = func() {
+		fmt.Fprint(fs.Output(), `Usage: scholscan train POSITIVES_FILE --rss-feeds RSS_FEEDS_FILE > model.json
+
+Train a TF-IDF + logistic regression model from positive examples and RSS feeds.
+
+The training workflow:
+  1. Load positive examples from POSITIVES_FILE
+  2. Fetch articles from the RSS feeds list
+  3. Exclude any positive examples from the RSS feed articles
+  4. Train the model with balanced classes
+  5. Output the trained model to stdout as JSON
+
+Flags:
+`)
+		fs.PrintDefaults()
+		fmt.Fprint(fs.Output(), `
+Arguments:
+  POSITIVES_FILE  Path to JSONL file with positive examples (required)
+
+Example:
+  scholscan train positives.jsonl --rss-feeds rss_world.txt > model.json
+`)
+	}
+
+	fs.StringVar(&c.rssFeedsFile, "rss-feeds", "", "Path to text file with RSS feed URLs (required)")
+	fs.BoolVar(&c.verboseOutput, "verbose", false, "Show progress information")
+	fs.Float64Var(&c.lambda, "lambda", 0.001, "L2 regularization parameter for logistic regression")
+	fs.IntVar(&c.minDF, "min-df", 2, "Minimum document frequency (absolute count)")
+	fs.Float64Var(&c.maxDF, "max-df", 0.8, "Maximum document frequency (ratio, 0-1)")
+	fs.IntVar(&c.ngramMax, "ngram-max", 2, "Maximum n-gram size (e.g., 1=unigrams, 2=unigrams+bigrams)")
+
+	// Check for the help flag first, so --help works without a positional argument.
+	for _, arg := range args {
+		if arg == "--help" || arg == "-h" {
+			fs.Usage()
+			return flag.ErrHelp
+		}
+	}
+
+	// Extract the positional argument (POSITIVES_FILE) before parsing flags.
+	if len(args) == 0 {
+		return fmt.Errorf("POSITIVES_FILE argument is required")
+	}
+	// The first argument is the positives file; the rest are flags.
+	c.positivesFile = args[0]
+	flagArgs := args[1:]
+
+	if err := fs.Parse(flagArgs); err != nil {
+		return err
+	}
+
+	if c.rssFeedsFile == "" {
+		return fmt.Errorf("--rss-feeds flag is required")
+	}
+
+	// Validate paths are safe (prevent directory traversal).
+	if strings.Contains(filepath.Clean(c.positivesFile), "..") {
+		return fmt.Errorf("invalid positives file path: directory traversal not allowed")
+	}
+	if strings.Contains(filepath.Clean(c.rssFeedsFile), "..") {
+		return fmt.Errorf("invalid RSS feeds file path: directory traversal not allowed")
+	}
+
+	return nil
+}
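The --rss-feeds file, as loadRSSURLs reads it further down, is plain text with one feed URL per line; blank lines and '#' comments are skipped, and file:// URLs are handled locally by fetchRSSFeed. An illustrative rss_world.txt (the URLs are made up):

    # World news feeds
    https://example.org/world/rss.xml
    https://example.com/feeds/atom.xml

    # Local snapshot, read via os.ReadFile
    file:///var/data/feeds/world-snapshot.xml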
+
+func (c *TrainCommand) Run(stdin io.Reader, stdout io.Writer) error {
+	if c.verboseOutput {
+		log.SetOutput(os.Stderr)
+		log.Println("Starting training workflow...")
+		log.Printf("Positives: %s", c.positivesFile)
+		log.Printf("RSS feeds: %s", c.rssFeedsFile)
+	}
+
+	if c.verboseOutput {
+		log.Printf("Loading positives from %s...", c.positivesFile)
+	}
+	positives, err := c.loadArticles(c.positivesFile)
+	if err != nil {
+		return fmt.Errorf("failed to load positives: %w", err)
+	}
+	if c.verboseOutput {
+		log.Printf("Loaded %d positive examples", len(positives))
+	}
+
+	if c.verboseOutput {
+		log.Printf("Loading RSS feeds from %s...", c.rssFeedsFile)
+	}
+	rssURLs, err := c.loadRSSURLs(c.rssFeedsFile)
+	if err != nil {
+		return fmt.Errorf("failed to load RSS feeds: %w", err)
+	}
+	if c.verboseOutput {
+		log.Printf("Found %d RSS feeds to fetch", len(rssURLs))
+	}
+
+	negatives, err := c.fetchFromRSSFeeds(rssURLs)
+	if err != nil {
+		return fmt.Errorf("failed to fetch from RSS feeds: %w", err)
+	}
+	if c.verboseOutput {
+		log.Printf("Fetched %d articles from RSS feeds", len(negatives))
+	}
+
+	negatives = c.excludePositives(negatives, positives)
+	if c.verboseOutput {
+		log.Printf("After exclusion: %d negative examples", len(negatives))
+	}
+
+	if len(positives) == 0 || len(negatives) == 0 {
+		return fmt.Errorf("need both positive (%d) and negative (%d) examples for training", len(positives), len(negatives))
+	}
+
+	if c.verboseOutput {
+		log.Println("Training model...")
+	}
+	model, err := c.trainModel(positives, negatives)
+	if err != nil {
+		return fmt.Errorf("failed to train model: %w", err)
+	}
+
+	// Output the model as indented JSON.
+	encoder := json.NewEncoder(stdout)
+	encoder.SetIndent("", "  ")
+	if err := encoder.Encode(model); err != nil {
+		return fmt.Errorf("failed to write model: %w", err)
+	}
+
+	return nil
+}
+
+// ============================================================================
+// ╺┳┓┏━┓╺┳╸┏━┓ ╻ ┏━┓┏━┓╺┳┓╻┏┓╻┏━╸
+// ┃┃┣━┫ ┃ ┣━┫ ┃ ┃ ┃┣━┫ ┃┃┃┃┗┫┃╺┓
+// ╺┻┛╹ ╹ ╹ ╹ ╹ ┗━╸┗━┛╹ ╹╺┻┛╹╹ ╹┗━┛
+// ============================================================================
+
+// loadArticles reads newline-delimited JSON (JSONL) articles from filename.
+// Malformed lines are skipped rather than failing the whole load. A
+// line-oriented scanner is used instead of a streaming json.Decoder: once a
+// Decoder hits a syntax error its read position no longer advances, so a
+// "skip and continue" loop around Decode can spin forever on the same bad
+// input instead of reaching EOF.
+func (c *TrainCommand) loadArticles(filename string) ([]*core.Article, error) {
+	file, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	var articles []*core.Article
+	scanner := bufio.NewScanner(file)
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // allow long lines
+	lineCount := 0
+	for scanner.Scan() {
+		lineCount++
+		line := strings.TrimSpace(scanner.Text())
+		if line == "" {
+			continue
+		}
+		var article core.Article
+		if err := json.Unmarshal([]byte(line), &article); err != nil {
+			// Skip malformed JSON lines; don't fail on bad input.
+			continue
+		}
+		articles = append(articles, &article)
+		if lineCount%500 == 0 && c.verboseOutput {
+			log.Printf("  Loaded %d articles so far", len(articles))
+		}
+	}
+	return articles, scanner.Err()
+}
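For orientation, loadArticles expects JSONL: one core.Article object per line. The field names below assume lowercase url/title/content JSON tags on core.Article, which this diff does not show, so treat the exact keys as an assumption; a two-line positives.jsonl sketch:

    {"url": "https://example.edu/a/101", "title": "A Survey of Lightweight Text Classifiers for News Streams", "content": "..."}
    {"url": "https://example.edu/a/102", "title": "TF-IDF Revisited: Title-Only Relevance Ranking", "content": "..."}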
+
+// loadRSSURLs loads RSS feed URLs from a text file, skipping blank lines
+// and '#' comments.
+func (c *TrainCommand) loadRSSURLs(filename string) ([]string, error) {
+	file, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	var urls []string
+	scanner := bufio.NewScanner(file)
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if line != "" && !strings.HasPrefix(line, "#") {
+			urls = append(urls, line)
+		}
+	}
+	return urls, scanner.Err()
+}
+
+// fetchFromRSSFeeds fetches articles from multiple RSS feeds in parallel,
+// one goroutine per feed. Feeds that fail are logged (in verbose mode) and
+// skipped; they do not abort the run.
+func (c *TrainCommand) fetchFromRSSFeeds(rssURLs []string) ([]*core.Article, error) {
+	client := core.DefaultHTTPClient
+	type result struct {
+		url      string
+		articles []*core.Article
+		err      error
+	}
+	resultChan := make(chan result, len(rssURLs))
+
+	for _, rssURL := range rssURLs {
+		go func(feedURL string) {
+			articles, err := c.fetchRSSFeed(client, feedURL)
+			resultChan <- result{url: feedURL, articles: articles, err: err}
+		}(rssURL)
+	}
+
+	var allArticles []*core.Article
+	for i := 0; i < len(rssURLs); i++ {
+		res := <-resultChan
+		if res.err != nil {
+			if c.verboseOutput {
+				log.Printf("%s: failed to fetch: %v", shortURL(res.url), res.err)
+			}
+		} else {
+			if c.verboseOutput {
+				log.Printf("%s: %d articles", shortURL(res.url), len(res.articles))
+			}
+			allArticles = append(allArticles, res.articles...)
+		}
+	}
+
+	return allArticles, nil
+}
+
+// ParseRSSFeed parses an RSS/Atom feed from the provided body into a slice
+// of Articles.
+func ParseRSSFeed(body []byte, baseURL string) ([]*core.Article, error) {
+	fp := gofeed.NewParser()
+	feed, err := fp.Parse(bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	var articles []*core.Article
+	for _, item := range feed.Items {
+		// Prefer explicit content; fall back to the description.
+		content := strings.TrimSpace(item.Content)
+		if content == "" {
+			content = item.Description
+		}
+		// Also check the custom content field (for <content> tags in RSS).
+		if content == "" && item.Custom != nil {
+			if c, ok := item.Custom["content"]; ok && c != "" {
+				content = c
+			}
+		}
+
+		// Clean and limit the content length.
+		content = core.CleanFeedContent(content)
+
+		articles = append(articles, &core.Article{
+			URL:     item.Link,
+			Title:   item.Title,
+			Content: content,
+		})
+	}
+	return articles, nil
+}
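To make the fallback order in ParseRSSFeed concrete: gofeed exposes an item's <content:encoded> as item.Content and <description> as item.Description. Given this minimal, made-up RSS item, the article body would come from the encoded block and the summary would be ignored:

    <item>
      <title>Example headline</title>
      <link>https://example.com/a/1</link>
      <description>Short summary, used only when no content block exists.</description>
      <content:encoded><![CDATA[<p>Full article body.</p>]]></content:encoded>
    </item>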
+
+// fetchRSSFeed fetches and parses a single RSS feed.
+func (c *TrainCommand) fetchRSSFeed(client *http.Client, rssURL string) ([]*core.Article, error) {
+	var body []byte
+	var err error
+
+	if strings.HasPrefix(rssURL, "file://") {
+		// Handle file:// URLs locally.
+		filePath := strings.TrimPrefix(rssURL, "file://")
+		body, err = os.ReadFile(filePath)
+		if err != nil {
+			return nil, fmt.Errorf("error reading file %s: %w", filePath, err)
+		}
+	} else {
+		// Handle HTTP/HTTPS URLs with a timeout and retry logic.
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cancel()
+
+		req, err := http.NewRequestWithContext(ctx, "GET", rssURL, nil)
+		if err != nil {
+			return nil, fmt.Errorf("error building request: %w", err)
+		}
+		req.Header.Set("User-Agent", core.PoliteUserAgent)
+
+		resp, err := core.DoRequestWithRetry(ctx, client, req)
+		if err != nil {
+			return nil, fmt.Errorf("error fetching %s: %w", rssURL, err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusOK {
+			return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, rssURL)
+		}
+
+		// Read the response body.
+		body, err = io.ReadAll(resp.Body)
+		if err != nil {
+			return nil, fmt.Errorf("error reading response from %s: %w", rssURL, err)
+		}
+	}
+
+	// Parse the RSS/Atom feed.
+	return ParseRSSFeed(body, rssURL)
+}
+
+// ============================================================================
+// ╺┳┓┏━┓╺┳╸┏━┓ ┏━┓┏━┓┏━╸┏━┓
+// ┃┃┣━┫ ┃ ┣━┫ ┣━┛┣┳┛┣╸ ┣━┛
+// ╺┻┛╹ ╹ ╹ ╹ ╹ ╹ ╹┗╸┗━╸╹
+// ============================================================================
+
+// excludePositives drops any negative whose URL also appears among the
+// positives, so labeled positives never leak into the negative class.
+func (c *TrainCommand) excludePositives(negatives, positives []*core.Article) []*core.Article {
+	// Build a set of positive URLs for O(1) lookup.
+	positiveURLs := make(map[string]bool)
+	for _, pos := range positives {
+		positiveURLs[pos.URL] = true
+	}
+
+	// Filter out the positives.
+	var filtered []*core.Article
+	for _, neg := range negatives {
+		if !positiveURLs[neg.URL] {
+			filtered = append(filtered, neg)
+		}
+	}
+
+	return filtered
+}
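excludePositives compares URLs as exact strings, so https://example.com/a and https://example.com/a/?utm_source=rss count as different articles. If that ever leaks positives into the negative class, a normalization pass over both sides would help; a minimal sketch, where normalizeURL is a hypothetical helper and not part of this commit:

    // normalizeURL (hypothetical): lowercase the host and drop the query,
    // fragment, and trailing slash so trivial URL variants compare equal.
    func normalizeURL(raw string) string {
    	u, err := url.Parse(raw)
    	if err != nil {
    		return raw
    	}
    	u.Host = strings.ToLower(u.Host)
    	u.RawQuery = ""
    	u.Fragment = ""
    	u.Path = strings.TrimSuffix(u.Path, "/")
    	return u.String()
    }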
+
+// splitTrainingData performs a deterministic 80/20 split (seed=42).
+// A deterministic split ensures reproducible model training across runs.
+func (c *TrainCommand) splitTrainingData(documents []string, labels []float64) (
+	trainDocs, valDocs []string,
+	trainLabels, valLabels []float64,
+) {
+	const validationSplitRatio = 0.2
+	const splitSeed = 42
+
+	if len(documents) < 3 {
+		// Not enough data to split; use all of it for training.
+		// A split requires at least 2 training documents to avoid MaxDF issues
+		// and at least 1 validation document.
+		return documents, nil, labels, nil
+	}
+
+	// Create a reproducible random source and shuffle indices.
+	rng := rand.New(rand.NewSource(splitSeed))
+	indices := make([]int, len(documents))
+	for i := range indices {
+		indices[i] = i
+	}
+	rng.Shuffle(len(indices), func(i, j int) {
+		indices[i], indices[j] = indices[j], indices[i]
+	})
+
+	splitIndex := int(float64(len(documents)) * (1.0 - validationSplitRatio))
+	trainIndices := indices[:splitIndex]
+	valIndices := indices[splitIndex:]
+
+	trainDocs = make([]string, len(trainIndices))
+	trainLabels = make([]float64, len(trainIndices))
+	for i, idx := range trainIndices {
+		trainDocs[i] = documents[idx]
+		trainLabels[i] = labels[idx]
+	}
+
+	valDocs = make([]string, len(valIndices))
+	valLabels = make([]float64, len(valIndices))
+	for i, idx := range valIndices {
+		valDocs[i] = documents[idx]
+		valLabels[i] = labels[idx]
+	}
+
+	return trainDocs, valDocs, trainLabels, valLabels
+}
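A quick check of the split arithmetic: with 1,000 documents, splitIndex = int(1000 * 0.8) = 800, giving an 800/200 train/validation split; with 10 documents it is 8/2. Note the shuffle is not stratified by label, so when positives are scarce the validation slice can end up with very few of them by chance.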
+
+// downsampleToBalance downsamples the majority class to a 1:1 ratio AFTER
+// vectorizer.Fit() to preserve IDF values.
+func (c *TrainCommand) downsampleToBalance(docs []string, labels []float64) ([]string, []float64) {
+	// Count positives and negatives.
+	var posDocs, negDocs []string
+	var posLabels, negLabels []float64
+
+	for i, label := range labels {
+		if label == 1.0 {
+			posDocs = append(posDocs, docs[i])
+			posLabels = append(posLabels, label)
+		} else {
+			negDocs = append(negDocs, docs[i])
+			negLabels = append(negLabels, label)
+		}
+	}
+
+	// If already balanced, return as-is.
+	if len(posDocs) == len(negDocs) {
+		return docs, labels
+	}
+
+	// Determine which class is the majority.
+	var majorityDocs, minorityDocs []string
+	var majorityLabels, minorityLabels []float64
+
+	if len(negDocs) > len(posDocs) {
+		// Negatives are the majority (the usual case).
+		majorityDocs, minorityDocs = negDocs, posDocs
+		majorityLabels, minorityLabels = negLabels, posLabels
+	} else {
+		// Positives are the majority (unlikely, but handled).
+		majorityDocs, minorityDocs = posDocs, negDocs
+		majorityLabels, minorityLabels = posLabels, negLabels
+	}
+
+	// Downsample the majority to match the minority size, with a fixed seed
+	// for reproducibility.
+	minoritySize := len(minorityDocs)
+	rng := rand.New(rand.NewSource(42))
+
+	// Create shuffled indices for downsampling.
+	indices := make([]int, len(majorityDocs))
+	for i := range indices {
+		indices[i] = i
+	}
+	rng.Shuffle(len(indices), func(i, j int) {
+		indices[i], indices[j] = indices[j], indices[i]
+	})
+
+	downsampledDocs := make([]string, 0, minoritySize*2)
+	downsampledLabels := make([]float64, 0, minoritySize*2)
+
+	// Add all minority samples.
+	downsampledDocs = append(downsampledDocs, minorityDocs...)
+	downsampledLabels = append(downsampledLabels, minorityLabels...)
+
+	// Add the downsampled majority.
+	for i := 0; i < minoritySize; i++ {
+		idx := indices[i]
+		downsampledDocs = append(downsampledDocs, majorityDocs[idx])
+		downsampledLabels = append(downsampledLabels, majorityLabels[idx])
+	}
+
+	return downsampledDocs, downsampledLabels
+}
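Putting illustrative numbers on the order of operations: if the train slice holds 240 positives and 3,760 negatives, vectorizer.Fit (below, in trainModel) sees all 4,000 titles, so document frequencies reflect the full corpus; downsampleToBalance then keeps the 240 positives plus a seeded random 240 of the negatives, and the regression trains on 480 balanced documents.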
+
+// ============================================================================
+// ╺┳╸┏━┓┏━┓╻┏┓╻ ┏┳┓┏━┓╺┳┓┏━╸╻
+// ┃ ┣┳┛┣━┫┃┃┗┫ ┃┃┃┃ ┃ ┃┃┣╸ ┃
+// ╹ ╹┗╸╹ ╹╹╹ ╹ ╹ ╹┗━┛╺┻┛┗━╸┗━╸
+// ============================================================================
+
+// trainModel trains a TF-IDF + logistic regression model.
+func (c *TrainCommand) trainModel(positives, negatives []*core.Article) (*core.ModelEnvelope, error) {
+	// Combine datasets and create labels.
+	var documents []string
+	var labels []float64
+
+	// Process positives, skipping articles with titles that are too short.
+	for _, article := range positives {
+		if len(article.Title) < 15 {
+			continue
+		}
+		documents = append(documents, article.Title)
+		labels = append(labels, 1.0)
+	}
+
+	// Process negatives with the same title-length filter.
+	for _, article := range negatives {
+		if len(article.Title) < 15 {
+			continue
+		}
+		documents = append(documents, article.Title)
+		labels = append(labels, 0.0)
+	}
+
+	// Use parameters from CLI flags (with defaults matching the Julia implementation).
+	const vocabCap = 50000
+
+	// Deterministic 80/20 split for train/validation.
+	trainDocs, valDocs, trainLabels, valLabels := c.splitTrainingData(documents, labels)
+
+	// Create the TF-IDF vectorizer with the specified parameters.
+	vectorizer := &core.TFIDFVectorizer{
+		NgramMin:   1,
+		NgramMax:   c.ngramMax,
+		MinDF:      c.minDF,
+		MaxDF:      c.maxDF,
+		VocabCap:   vocabCap,
+		Vocabulary: make(map[string]float64),
+	}
+	// Fit the vectorizer on UNBALANCED training data to match the Julia
+	// implementation. This preserves document frequencies properly.
+	vectorizer.Fit(trainDocs)
+
+	// Downsample negatives to a 1:1 ratio AFTER fitting (match the Julia approach).
+	balancedTrainDocs, balancedTrainLabels := c.downsampleToBalance(trainDocs, trainLabels)
+
+	// Transform both the training and validation sets.
+	trainVectors := vectorizer.Transform(balancedTrainDocs)
+	valVectors := vectorizer.Transform(valDocs)
+
+	// Use uniform class weights, since the dataset has been balanced.
+	classWeights := map[float64]float64{
+		1.0: 1.0,
+		0.0: 1.0,
+	}
+
+	// Train logistic regression with the specified lambda parameter.
+	lr := &core.LogisticRegression{
+		LearningRate: 0.5,
+		Lambda:       c.lambda,
+		Iterations:   500,
+		Tolerance:    0.000001,
+	}
+	lr.Validate()
+	weights, err := lr.Fit(trainVectors, balancedTrainLabels, classWeights)
+	if err != nil {
+		return nil, fmt.Errorf("failed to train logistic regression model: %w", err)
+	}
+
+	// Find the best threshold on the validation set.
+	recommendedThreshold, scoreDistributions := c.findBestThreshold(valVectors, valLabels, weights)
+
+	// Count classes for metadata.
+	var posCount, negCount float64
+	for _, label := range labels {
+		if label == 1.0 {
+			posCount++
+		} else {
+			negCount++
+		}
+	}
+
+	// Create the model envelope.
+	model := &core.ModelEnvelope{
+		Algorithm: "tfidf-go",
+		Impl:      "go",
+		Version:   "1",
+		CreatedAt: time.Now().UTC(),
+		Meta: map[string]any{
+			"positives": len(positives),
+			"negatives": len(negatives),
+			"class_counts": map[string]int{
+				"pos": int(posCount),
+				"neg": int(negCount),
+			},
+			"vectorizer_params": map[string]any{
+				"ngram_min": vectorizer.NgramMin,
+				"ngram_max": vectorizer.NgramMax,
+				"min_df":    vectorizer.MinDF,
+				"max_df":    vectorizer.MaxDF,
+				"vocab_cap": vectorizer.VocabCap,
+			},
+			"model_params": map[string]any{
+				"learning_rate": lr.LearningRate,
+				"lambda":        lr.Lambda,
+				"iterations":    lr.Iterations,
+				"tolerance":     lr.Tolerance,
+			},
+			"recommended_threshold": recommendedThreshold,
+			"score_distributions":   scoreDistributions,
+		},
+		Vectorizer:   vectorizer.Vocabulary,
+		OrderedVocab: vectorizer.OrderedVocab,
+		Weights:      weights,
+	}
+
+	return model, nil
+}
+
+// ============================================================================
+// ┏┳┓┏━╸╺┳╸┏━┓╻┏━╸┏━┓
+// ┃┃┃┣╸ ┃ ┣┳┛┃┃ ┗━┓
+// ╹ ╹┗━╸ ╹ ╹┗╸╹┗━╸┗━┛
+// ============================================================================
+
+// ClassificationMetrics holds the evaluation metrics.
+type ClassificationMetrics struct {
+	TruePositives  int
+	TrueNegatives  int
+	FalsePositives int
+	FalseNegatives int
+	Accuracy       float64
+	Precision      float64
+	Recall         float64
+	F1Score        float64
+}
+
+// Calculate computes the metrics from the raw counts.
+func (m *ClassificationMetrics) Calculate() {
+	total := m.TruePositives + m.TrueNegatives + m.FalsePositives + m.FalseNegatives
+
+	if total > 0 {
+		m.Accuracy = float64(m.TruePositives+m.TrueNegatives) / float64(total)
+	}
+
+	if m.TruePositives+m.FalsePositives > 0 {
+		m.Precision = float64(m.TruePositives) / float64(m.TruePositives+m.FalsePositives)
+	}
+
+	if m.TruePositives+m.FalseNegatives > 0 {
+		m.Recall = float64(m.TruePositives) / float64(m.TruePositives+m.FalseNegatives)
+	}
+
+	if m.Precision+m.Recall > 0 {
+		m.F1Score = 2 * (m.Precision * m.Recall) / (m.Precision + m.Recall)
+	}
+}
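A worked pass through Calculate: with TP=8, FP=2, FN=4, TN=86, accuracy = 94/100 = 0.94, precision = 8/10 = 0.80, recall = 8/12 ≈ 0.667, and F1 = 2 * (0.80 * 0.667) / (0.80 + 0.667) ≈ 0.727.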
+
+// findBestThreshold sweeps a range of thresholds on a validation set to find
+// the one that maximizes Youden's J statistic.
+func (c *TrainCommand) findBestThreshold(
+	validationVectors [][]float64,
+	validationLabels []float64,
+	weights []float64,
+) (float64, map[string]any) {
+	if len(validationVectors) == 0 {
+		return 0.5, nil // Default if there is no validation data.
+	}
+
+	scores := make([]float64, len(validationVectors))
+	for i, vector := range validationVectors {
+		score, err := core.PredictScore(vector, weights)
+		if err != nil {
+			// This should not happen with valid data, but as a fallback:
+			return 0.5, nil
+		}
+		scores[i] = score
+	}
+
+	// Collect the score distributions by label.
+	var posScores, negScores []float64
+	for i, score := range scores {
+		if validationLabels[i] == 1.0 {
+			posScores = append(posScores, score)
+		} else {
+			negScores = append(negScores, score)
+		}
+	}
+
+	// Compute stats for each class.
+	posStats := computeScoreStats(posScores)
+	negStats := computeScoreStats(negScores)
+
+	// Calculate Cohen's d (effect size) to measure class separation in the
+	// learned score space.
+	posMean := posStats["mean"]
+	negMean := negStats["mean"]
+	posStd := posStats["std"]
+	negStd := negStats["std"]
+
+	var cohensD float64
+	if posStd > 0 && negStd > 0 {
+		pooledStd := math.Sqrt((posStd*posStd + negStd*negStd) / 2)
+		cohensD = math.Abs(posMean-negMean) / pooledStd
+	}
+
+	// Calculate the separation ratio to understand how much the classes
+	// overlap on the score scale.
+	totalRange := math.Max(posStats["max"], negStats["max"]) - math.Min(posStats["min"], negStats["min"])
+	overlapStart := math.Max(posStats["min"], negStats["min"])
+	overlapEnd := math.Min(posStats["max"], negStats["max"])
+	overlapRange := math.Max(0, overlapEnd-overlapStart)
+	separationRatio := 0.0
+	if totalRange > 0 {
+		separationRatio = (totalRange - overlapRange) / totalRange
+	}
+
+	// Find the threshold that balances false positives and false negatives
+	// using Youden's J. This metric (Sensitivity + Specificity - 1) equally
+	// weights both false positive and false negative rates. Why not F1? F1
+	// biases toward precision when classes are imbalanced; a validation set
+	// of 10 positives and 1000 negatives would push the threshold too high.
+	// Youden's J treats both types of error equally, which better reflects
+	// real use: missing a relevant article (false negative) is as bad as
+	// showing an irrelevant one (false positive).
+	bestCombinedScore := -1.0
+	bestThreshold := 0.5
+	var bestMetrics ClassificationMetrics
+
+	boolLabels := make([]bool, len(validationLabels))
+	for i, l := range validationLabels {
+		boolLabels[i] = l == 1.0
+	}
+
+	for i := 5; i <= 95; i++ {
+		threshold := float64(i) / 100.0
+		metrics := computeMetrics(scores, boolLabels, threshold)
+
+		sensitivity := metrics.Recall // TPR: TP / (TP + FN)
+		specificity := 0.0
+		if metrics.TrueNegatives+metrics.FalsePositives > 0 {
+			specificity = float64(metrics.TrueNegatives) / float64(metrics.TrueNegatives+metrics.FalsePositives)
+		}
+		youdenJ := sensitivity + specificity - 1.0
+
+		if youdenJ > bestCombinedScore {
+			bestCombinedScore = youdenJ
+			bestThreshold = threshold
+			bestMetrics = metrics
+		}
+	}
+
+	distributions := map[string]any{
+		"positive":         posStats,
+		"negative":         negStats,
+		"cohens_d":         cohensD,
+		"separation_ratio": separationRatio,
+		"best_f1":          bestMetrics.F1Score,
+		"best_precision":   bestMetrics.Precision,
+		"best_recall":      bestMetrics.Recall,
+	}
+
+	return bestThreshold, distributions
+}
+
+// computeScoreStats computes min, max, mean, and (population) std for a slice of scores.
+func computeScoreStats(scores []float64) map[string]float64 {
+	if len(scores) == 0 {
+		return map[string]float64{
+			"min":  0.0,
+			"max":  0.0,
+			"mean": 0.0,
+			"std":  0.0,
+		}
+	}
+
+	min, max := scores[0], scores[0]
+	sum := 0.0
+
+	for _, score := range scores {
+		if score < min {
+			min = score
+		}
+		if score > max {
+			max = score
+		}
+		sum += score
+	}
+
+	mean := sum / float64(len(scores))
+
+	// Calculate the population standard deviation.
+	variance := 0.0
+	for _, score := range scores {
+		diff := score - mean
+		variance += diff * diff
+	}
+	variance /= float64(len(scores))
+	std := math.Sqrt(variance)
+
+	return map[string]float64{
+		"min":  min,
+		"max":  max,
+		"mean": mean,
+		"std":  std,
+	}
+}
+
+// computeMetrics calculates classification metrics at the given threshold.
+func computeMetrics(scores []float64, labels []bool, threshold float64) ClassificationMetrics {
+	var metrics ClassificationMetrics
+	for i, score := range scores {
+		predicted := score > threshold
+		actual := labels[i]
+
+		if predicted && actual {
+			metrics.TruePositives++
+		} else if predicted && !actual {
+			metrics.FalsePositives++
+		} else if !predicted && actual {
+			metrics.FalseNegatives++
+		} else {
+			metrics.TrueNegatives++
+		}
+	}
+	metrics.Calculate()
+	return metrics
+}
+
+// ============================================================================
+// ╻ ╻┏━╸╻ ┏━┓┏━╸┏━┓┏━┓
+// ┣━┫┣╸ ┃ ┣━┛┣╸ ┣┳┛┗━┓
+// ╹ ╹┗━╸┗━╸╹ ┗━╸╹┗╸┗━┛
+// ============================================================================
+
+// shortURL formats a URL to be human-readable and not too long.
+func shortURL(urlStr string) string {
+	u, err := url.Parse(urlStr)
+	if err != nil {
+		return urlStr
+	}
+
+	path := u.Path
+	if len(path) > 30 {
+		path = path[:30] + "..."
+	}
+
+	return u.Host + path
+}
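A worked example of the threshold sweep, with assumed numbers: a threshold yielding sensitivity 0.90 and specificity 0.80 scores J = 0.90 + 0.80 - 1 = 0.70, while a looser one with sensitivity 0.98 but specificity 0.50 scores J = 0.48, so the sweep keeps the first even though the second catches more positives. That trade-off is exactly the behavior the comment in findBestThreshold argues for.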
