// Train command learns model from positive examples and RSS feeds. // Loads positives, fetches RSS feeds as negatives, excludes overlap, // trains TF-IDF + logistic regression with 1:1 class balancing. // Outputs model with validation threshold to stdout. package cmds import ( "bufio" "bytes" "context" "encoding/json" "flag" "fmt" "io" "log" "math" "math/rand" "net/http" "net/url" "os" "path/filepath" "strings" "time" "github.com/mmcdole/gofeed" "scholscan/core" ) // ============================================================================ // ┏━╸┏┳┓╺┳┓ ┏━┓┏┓ ┏┓ // ┃ ┃┃┃ ┃┃ ┃ ┃┣┻┓ ┃ // ┗━╸╹ ╹╺┻┛ ┗━┛┗━┛┗━┛ // ============================================================================ // Learns model from positive examples and RSS feeds // Outputs trained model JSON to stdout type TrainCommand struct { positivesFile string rssFeedsFile string verboseOutput bool lambda float64 minDF int maxDF float64 ngramMax int } func (c *TrainCommand) Name() string { return "train" } func (c *TrainCommand) Init(args []string) error { fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError) fs.Usage = func() { fmt.Fprint(fs.Output(), `Usage: scholscan train POSITIVES_FILE --rss-feeds RSS_FEEDS_FILE > model.json Train a TF-IDF + logistic regression model from positive examples and RSS feeds. The training workflow: 1. Load positive examples from POSITIVES_FILE 2. Fetch articles from RSS feeds list 3. Exclude any positive examples from RSS feed articles 4. Train model with balanced classes 5. Output trained model to stdout as JSON Flags: `) fs.PrintDefaults() fmt.Fprint(fs.Output(), ` Arguments: POSITIVES_FILE Path to JSONL file with positive examples (required) Example: scholscan train positives.jsonl --rss-feeds rss_world.txt > model.json `) } fs.StringVar(&c.rssFeedsFile, "rss-feeds", "", "Path to text file with RSS feed URLs (required)") fs.BoolVar(&c.verboseOutput, "verbose", false, "Show progress information") fs.Float64Var(&c.lambda, "lambda", 0.001, "L2 regularization parameter for logistic regression") fs.IntVar(&c.minDF, "min-df", 2, "Minimum document frequency (absolute count)") fs.Float64Var(&c.maxDF, "max-df", 0.8, "Maximum document frequency (ratio, 0-1)") fs.IntVar(&c.ngramMax, "ngram-max", 2, "Maximum n-gram size (e.g., 1=unigrams, 2=unigrams+bigrams)") // Check for help flag first for _, arg := range args { if arg == "--help" || arg == "-h" { fs.Usage() return flag.ErrHelp } } // Extract positional argument (POSITIVES_FILE) before parsing flags if len(args) == 0 { return fmt.Errorf("POSITIVES_FILE argument is required") } // The first argument should be the positives file, the rest are flags c.positivesFile = args[0] flagArgs := args[1:] if err := fs.Parse(flagArgs); err != nil { return err } if c.rssFeedsFile == "" { return fmt.Errorf("--rss-feeds flag is required") } // Validate paths are safe (prevent directory traversal) if strings.Contains(filepath.Clean(c.positivesFile), "..") { return fmt.Errorf("invalid positives file path: directory traversal not allowed") } if strings.Contains(filepath.Clean(c.rssFeedsFile), "..") { return fmt.Errorf("invalid RSS feeds file path: directory traversal not allowed") } return nil } func (c *TrainCommand) Run(stdin io.Reader, stdout io.Writer) error { if c.verboseOutput { log.SetOutput(os.Stderr) log.Println("Starting training workflow...") log.Printf("Positives: %s", c.positivesFile) log.Printf("RSS feeds: %s", c.rssFeedsFile) } if c.verboseOutput { log.Printf("Loading positives from %s...", c.positivesFile) } positives, err := c.loadArticles(c.positivesFile) if err != nil { return fmt.Errorf("failed to load positives: %w", err) } if c.verboseOutput { log.Printf("Loaded %d positive examples", len(positives)) } if c.verboseOutput { log.Printf("Loading RSS feeds from %s...", c.rssFeedsFile) } rssURLs, err := c.loadRSSURLs(c.rssFeedsFile) if err != nil { return fmt.Errorf("failed to load RSS feeds: %w", err) } if c.verboseOutput { log.Printf("Found %d RSS feeds to fetch", len(rssURLs)) } negatives, err := c.fetchFromRSSFeeds(rssURLs) if err != nil { return fmt.Errorf("failed to fetch from RSS feeds: %w", err) } if c.verboseOutput { log.Printf("Fetched %d articles from RSS feeds", len(negatives)) } negatives = c.excludePositives(negatives, positives) if c.verboseOutput { log.Printf("After exclusion: %d negative examples", len(negatives)) } if len(positives) == 0 || len(negatives) == 0 { return fmt.Errorf("need both positive (%d) and negative (%d) examples for training", len(positives), len(negatives)) } if c.verboseOutput { log.Println("Training model...") } model, err := c.trainModel(positives, negatives) if err != nil { return fmt.Errorf("failed to train model: %w", err) } // Output model encoder := json.NewEncoder(stdout) encoder.SetIndent("", " ") if err := encoder.Encode(model); err != nil { return fmt.Errorf("failed to write model: %w", err) } return nil } // ============================================================================ // ╺┳┓┏━┓╺┳╸┏━┓ ╻ ┏━┓┏━┓╺┳┓╻┏┓╻┏━╸ // ┃┃┣━┫ ┃ ┣━┫ ┃ ┃ ┃┣━┫ ┃┃┃┃┗┫┃╺┓ // ╺┻┛╹ ╹ ╹ ╹ ╹ ┗━╸┗━┛╹ ╹╺┻┛╹╹ ╹┗━┛ // ============================================================================ func (c *TrainCommand) loadArticles(filename string) ([]*core.Article, error) { file, err := os.Open(filename) if err != nil { return nil, err } defer file.Close() var articles []*core.Article decoder := json.NewDecoder(file) lineCount := 0 for { var article core.Article if err := decoder.Decode(&article); err != nil { if err == io.EOF { break } // Skip malformed json lines, don't fail on bad input. lineCount++ continue } articles = append(articles, &article) lineCount++ if lineCount%500 == 0 && c.verboseOutput { log.Printf(" Loaded %d articles so far", len(articles)) } } return articles, nil } // loadRSSURLs loads RSS feed URLs from a text file func (c *TrainCommand) loadRSSURLs(filename string) ([]string, error) { file, err := os.Open(filename) if err != nil { return nil, err } defer file.Close() var urls []string scanner := bufio.NewScanner(file) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line != "" && !strings.HasPrefix(line, "#") { urls = append(urls, line) } } return urls, scanner.Err() } // fetchFromRSSFeeds fetches articles from multiple RSS feeds in parallel func (c *TrainCommand) fetchFromRSSFeeds(rssURLs []string) ([]*core.Article, error) { client := core.DefaultHTTPClient type result struct { url string articles []*core.Article err error } resultChan := make(chan result, len(rssURLs)) for _, rssURL := range rssURLs { go func(url string) { articles, err := c.fetchRSSFeed(client, url) resultChan <- result{url: url, articles: articles, err: err} }(rssURL) } var allArticles []*core.Article for i := 0; i < len(rssURLs); i++ { res := <-resultChan if res.err != nil { if c.verboseOutput { log.Printf("%s: failed to fetch", shortURL(res.url)) } } else { if c.verboseOutput { log.Printf("%s: %d articles", shortURL(res.url), len(res.articles)) } allArticles = append(allArticles, res.articles...) } } return allArticles, nil } // ParseRSSFeed parses an RSS/Atom feed from the provided body into a slice of Articles. func ParseRSSFeed(body []byte, baseURL string) ([]*core.Article, error) { fp := gofeed.NewParser() feed, err := fp.Parse(bytes.NewReader(body)) if err != nil { return nil, err } var articles []*core.Article for _, item := range feed.Items { // Prefer explicit content; fall back to description. content := strings.TrimSpace(item.Content) if content == "" { content = item.Description } // Also check custom content field (for tags in RSS) if content == "" && item.Custom != nil { if c, ok := item.Custom["content"]; ok && c != "" { content = c } } // Clean and limit content length content = core.CleanFeedContent(content) articles = append(articles, &core.Article{ URL: item.Link, Title: item.Title, Content: content, }) } return articles, nil } // fetchRSSFeed fetches and parses a single RSS feed func (c *TrainCommand) fetchRSSFeed(client *http.Client, rssURL string) ([]*core.Article, error) { var body []byte var err error // Handle file:// URLs locally if strings.HasPrefix(rssURL, "file://") { // Remove file:// prefix filePath := strings.TrimPrefix(rssURL, "file://") body, err = os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("error reading file %s: %w", filePath, err) } } else { // Handle HTTP/HTTPS URLs normally req, err := http.NewRequest("GET", rssURL, nil) if err != nil { return nil, fmt.Errorf("error building request: %w", err) } req.Header.Set("User-Agent", core.PoliteUserAgent) // Make request with retry logic ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() resp, err := core.DoRequestWithRetry(ctx, client, req) if err != nil { return nil, fmt.Errorf("error fetching %s: %w", rssURL, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, rssURL) } // Read response body body, err = io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("error reading response from %s: %w", rssURL, err) } } // Parse RSS/Atom feed return ParseRSSFeed(body, rssURL) } // ============================================================================ // ╺┳┓┏━┓╺┳╸┏━┓ ┏━┓┏━┓┏━╸┏━┓ // ┃┃┣━┫ ┃ ┣━┫ ┣━┛┣┳┛┣╸ ┣━┛ // ╺┻┛╹ ╹ ╹ ╹ ╹ ╹ ╹┗╸┗━╸╹ // ============================================================================ func (c *TrainCommand) excludePositives(negatives, positives []*core.Article) []*core.Article { // Build set of positive URLs for O(1) lookup positiveURLs := make(map[string]bool) for _, pos := range positives { positiveURLs[pos.URL] = true } // Filter out positives var filtered []*core.Article for _, neg := range negatives { if !positiveURLs[neg.URL] { filtered = append(filtered, neg) } } return filtered } // splitTrainingData performs a deterministic 80/20 split (seed=42). // Deterministic ensures reproducible model training across runs. func (c *TrainCommand) splitTrainingData(documents []string, labels []float64) ( trainDocs, valDocs []string, trainLabels, valLabels []float64, ) { const validationSplitRatio = 0.2 const splitSeed = 42 if len(documents) < 3 { // Not enough data to split, use all for training. // A split requires at least 2 training documents to avoid MaxDF issues // and at least 1 validation document. return documents, nil, labels, nil } // Create a reproducible random source and shuffle indices. rng := rand.New(rand.NewSource(splitSeed)) indices := make([]int, len(documents)) for i := range indices { indices[i] = i } rng.Shuffle(len(indices), func(i, j int) { indices[i], indices[j] = indices[j], indices[i] }) splitIndex := int(float64(len(documents)) * (1.0 - validationSplitRatio)) trainIndices := indices[:splitIndex] valIndices := indices[splitIndex:] trainDocs = make([]string, len(trainIndices)) trainLabels = make([]float64, len(trainIndices)) for i, idx := range trainIndices { trainDocs[i] = documents[idx] trainLabels[i] = labels[idx] } valDocs = make([]string, len(valIndices)) valLabels = make([]float64, len(valIndices)) for i, idx := range valIndices { valDocs[i] = documents[idx] valLabels[i] = labels[idx] } return trainDocs, valDocs, trainLabels, valLabels } // Downsample majority class to 1:1 ratio AFTER vectorizer.Fit() to preserve IDF values. func (c *TrainCommand) downsampleToBalance(docs []string, labels []float64) ([]string, []float64) { // Count positives and negatives var posDocs, negDocs []string var posLabels, negLabels []float64 for i, label := range labels { if label == 1.0 { posDocs = append(posDocs, docs[i]) posLabels = append(posLabels, label) } else { negDocs = append(negDocs, docs[i]) negLabels = append(negLabels, label) } } // If already balanced, return as-is if len(posDocs) == len(negDocs) { return docs, labels } // Determine which class is majority var majorityDocs, minorityDocs []string var majorityLabels, minorityLabels []float64 if len(negDocs) > len(posDocs) { // Negatives are majority majorityDocs, minorityDocs = negDocs, posDocs majorityLabels, minorityLabels = negLabels, posLabels } else { // Positives are majority (unlikely but handle) majorityDocs, minorityDocs = posDocs, negDocs majorityLabels, minorityLabels = posLabels, negLabels } // Downsample majority to match minority size minoritySize := len(minorityDocs) rng := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility // Create random indices for downsampling indices := make([]int, len(majorityDocs)) for i := range indices { indices[i] = i } rng.Shuffle(len(indices), func(i, j int) { indices[i], indices[j] = indices[j], indices[i] }) // Select downsampled majority downsampledDocs := make([]string, 0, minoritySize*2) downsampledLabels := make([]float64, 0, minoritySize*2) // Add all minority samples downsampledDocs = append(downsampledDocs, minorityDocs...) downsampledLabels = append(downsampledLabels, minorityLabels...) // Add downsampled majority for i := 0; i < minoritySize; i++ { idx := indices[i] downsampledDocs = append(downsampledDocs, majorityDocs[idx]) downsampledLabels = append(downsampledLabels, majorityLabels[idx]) } return downsampledDocs, downsampledLabels } // ============================================================================ // ╺┳╸┏━┓┏━┓╻┏┓╻ ┏┳┓┏━┓╺┳┓┏━╸╻ // ┃ ┣┳┛┣━┫┃┃┗┫ ┃┃┃┃ ┃ ┃┃┣╸ ┃ // ╹ ╹┗╸╹ ╹╹╹ ╹ ╹ ╹┗━┛╺┻┛┗━╸┗━╸ // ============================================================================ // trainModel trains a TF-IDF + logistic regression model func (c *TrainCommand) trainModel(positives, negatives []*core.Article) (*core.ModelEnvelope, error) { // Combine datasets and create labels var documents []string var labels []float64 // Process positives for _, article := range positives { // Skip articles with titles that are too short if len(article.Title) < 15 { continue } documents = append(documents, article.Title) labels = append(labels, 1.0) } // Process negatives for _, article := range negatives { // Skip articles with titles that are too short if len(article.Title) < 15 { continue } documents = append(documents, article.Title) labels = append(labels, 0.0) } // Use parameters from CLI flags (with defaults matching Julia implementation) const vocabCap = 50000 // Deterministic 80/20 split for train/validation trainDocs, valDocs, trainLabels, valLabels := c.splitTrainingData(documents, labels) // Create TF-IDF vectorizer with the specified parameters vectorizer := &core.TFIDFVectorizer{ NgramMin: 1, NgramMax: c.ngramMax, MinDF: c.minDF, MaxDF: c.maxDF, VocabCap: vocabCap, Vocabulary: make(map[string]float64), } // Fit vectorizer on UNBALANCED training data to match Julia implementation // This preserves document frequencies properly vectorizer.Fit(trainDocs) // Downsample negatives to 1:1 ratio AFTER fitting (match Julia approach) balancedTrainDocs, balancedTrainLabels := c.downsampleToBalance(trainDocs, trainLabels) // Transform both training and validation sets trainVectors := vectorizer.Transform(balancedTrainDocs) valVectors := vectorizer.Transform(valDocs) // Use uniform class weights since we've balanced the dataset classWeights := map[float64]float64{ 1.0: 1.0, 0.0: 1.0, } // Train logistic regression with the specified lambda parameter lr := &core.LogisticRegression{ LearningRate: 0.5, Lambda: c.lambda, Iterations: 500, Tolerance: 0.000001, } lr.Validate() weights, err := lr.Fit(trainVectors, balancedTrainLabels, classWeights) if err != nil { return nil, fmt.Errorf("failed to train logistic regression model: %w", err) } // Find the best threshold on the validation set recommendedThreshold, scoreDistributions := c.findBestThreshold(valVectors, valLabels, weights) // Count classes for metadata var posCount, negCount float64 for _, label := range labels { if label == 1.0 { posCount++ } else { negCount++ } } // Create model envelope model := &core.ModelEnvelope{ Algorithm: "tfidf-go", Impl: "go", Version: "1", CreatedAt: time.Now().UTC(), Meta: map[string]any{ "positives": len(positives), "negatives": len(negatives), "class_counts": map[string]int{ "pos": int(posCount), "neg": int(negCount), }, "vectorizer_params": map[string]any{ "ngram_min": vectorizer.NgramMin, "ngram_max": vectorizer.NgramMax, "min_df": vectorizer.MinDF, "max_df": vectorizer.MaxDF, "vocab_cap": vectorizer.VocabCap, }, "model_params": map[string]any{ "learning_rate": lr.LearningRate, "lambda": lr.Lambda, "iterations": lr.Iterations, "tolerance": lr.Tolerance, }, "recommended_threshold": recommendedThreshold, "score_distributions": scoreDistributions, }, Vectorizer: vectorizer.Vocabulary, OrderedVocab: vectorizer.OrderedVocab, Weights: weights, } return model, nil } // ============================================================================ // ┏┳┓┏━╸╺┳╸┏━┓╻┏━╸┏━┓ // ┃┃┃┣╸ ┃ ┣┳┛┃┃ ┗━┓ // ╹ ╹┗━╸ ╹ ╹┗╸╹┗━╸┗━┛ // ============================================================================ // ClassificationMetrics holds the evaluation metrics type ClassificationMetrics struct { TruePositives int TrueNegatives int FalsePositives int FalseNegatives int Accuracy float64 Precision float64 Recall float64 F1Score float64 } // Calculate computes the metrics from raw counts func (m *ClassificationMetrics) Calculate() { total := m.TruePositives + m.TrueNegatives + m.FalsePositives + m.FalseNegatives if total > 0 { m.Accuracy = float64(m.TruePositives+m.TrueNegatives) / float64(total) } if m.TruePositives+m.FalsePositives > 0 { m.Precision = float64(m.TruePositives) / float64(m.TruePositives+m.FalsePositives) } if m.TruePositives+m.FalseNegatives > 0 { m.Recall = float64(m.TruePositives) / float64(m.TruePositives+m.FalseNegatives) } if m.Precision+m.Recall > 0 { m.F1Score = 2 * (m.Precision * m.Recall) / (m.Precision + m.Recall) } } // findBestThreshold sweeps a range of thresholds on a validation set to find // the one that maximizes combined F1 + separation score. func (c *TrainCommand) findBestThreshold( validationVectors [][]float64, validationLabels []float64, weights []float64, ) (float64, map[string]any) { if len(validationVectors) == 0 { return 0.5, nil // Default if no validation data } scores := make([]float64, len(validationVectors)) for i, vector := range validationVectors { score, err := core.PredictScore(vector, weights) if err != nil { // This should not happen with valid data, but as a fallback: return 0.5, nil } scores[i] = score } // Collect score distributions by label var posScores, negScores []float64 for i, score := range scores { if validationLabels[i] == 1.0 { posScores = append(posScores, score) } else { negScores = append(negScores, score) } } // Compute stats for each class posStats := computeScoreStats(posScores) negStats := computeScoreStats(negScores) // Calculate Cohen's d (effect size) to measure class separation in the learned space posMean := posStats["mean"] negMean := negStats["mean"] posStd := posStats["std"] negStd := negStats["std"] var cohensD float64 if posStd > 0 && negStd > 0 { pooledStd := math.Sqrt((posStd*posStd + negStd*negStd) / 2) cohensD = math.Abs(posMean-negMean) / pooledStd } // Calculate separation ratio to understand how much the classes overlap on the score scale totalRange := math.Max(posStats["max"], negStats["max"]) - math.Min(posStats["min"], negStats["min"]) overlapStart := math.Max(posStats["min"], negStats["min"]) overlapEnd := math.Min(posStats["max"], negStats["max"]) overlapRange := math.Max(0, overlapEnd-overlapStart) separationRatio := 0.0 if totalRange > 0 { separationRatio = (totalRange - overlapRange) / totalRange } // Find threshold that balances false positives and false negatives using Youden's J. // This metric (Sensitivity + Specificity - 1) equally weights both false positive // and false negative rates. Why not F1? F1 biases toward precision when classes // are imbalanced; a validation set of 10 positives and 1000 negatives would push // the threshold too high. Youden's J treats both types of error equally, which // better reflects real use: missing a relevant article (false negative) is as bad // as showing an irrelevant one (false positive). bestCombinedScore := -1.0 bestThreshold := 0.5 var bestMetrics ClassificationMetrics boolLabels := make([]bool, len(validationLabels)) for i, l := range validationLabels { boolLabels[i] = l == 1.0 } for i := 5; i <= 95; i++ { threshold := float64(i) / 100.0 metrics := computeMetrics(scores, boolLabels, threshold) sensitivity := metrics.Recall // TPR: TP / (TP + FN) specificity := 0.0 if metrics.TrueNegatives+metrics.FalsePositives > 0 { specificity = float64(metrics.TrueNegatives) / float64(metrics.TrueNegatives+metrics.FalsePositives) } youdenJ := sensitivity + specificity - 1.0 if youdenJ > bestCombinedScore { bestCombinedScore = youdenJ bestThreshold = threshold bestMetrics = metrics } } distributions := map[string]any{ "positive": posStats, "negative": negStats, "cohens_d": cohensD, "separation_ratio": separationRatio, "best_f1": bestMetrics.F1Score, "best_precision": bestMetrics.Precision, "best_recall": bestMetrics.Recall, } return bestThreshold, distributions } // computeScoreStats computes min, max, mean, and std for a slice of scores func computeScoreStats(scores []float64) map[string]float64 { if len(scores) == 0 { return map[string]float64{ "min": 0.0, "max": 0.0, "mean": 0.0, "std": 0.0, } } min, max := scores[0], scores[0] sum := 0.0 for _, score := range scores { if score < min { min = score } if score > max { max = score } sum += score } mean := sum / float64(len(scores)) // Calculate standard deviation variance := 0.0 for _, score := range scores { diff := score - mean variance += diff * diff } variance /= float64(len(scores)) std := math.Sqrt(variance) return map[string]float64{ "min": min, "max": max, "mean": mean, "std": std, } } // computeMetrics calculates classification metrics func computeMetrics(scores []float64, labels []bool, threshold float64) ClassificationMetrics { var metrics ClassificationMetrics for i, score := range scores { predicted := score > threshold actual := labels[i] if predicted && actual { metrics.TruePositives++ } else if predicted && !actual { metrics.FalsePositives++ } else if !predicted && actual { metrics.FalseNegatives++ } else { metrics.TrueNegatives++ } } metrics.Calculate() return metrics } // ============================================================================ // ╻ ╻┏━╸╻ ┏━┓┏━╸┏━┓┏━┓ // ┣━┫┣╸ ┃ ┣━┛┣╸ ┣┳┛┗━┓ // ╹ ╹┗━╸┗━╸╹ ┗━╸╹┗╸┗━┛ // ============================================================================ // shortURL formats a URL to be human-readable and not too long func shortURL(urlStr string) string { u, err := url.Parse(urlStr) if err != nil { return urlStr } path := u.Path if len(path) > 30 { path = path[:30] + "..." } return u.Host + path }