Diffstat (limited to 'cmds/scan.go')
-rw-r--r--	cmds/scan.go	416
1 file changed, 416 insertions(+), 0 deletions(-)
diff --git a/cmds/scan.go b/cmds/scan.go
new file mode 100644
index 0000000..789157c
--- /dev/null
+++ b/cmds/scan.go
@@ -0,0 +1,416 @@
+// Scan command: filters articles using a trained model.
+//
+// Articles come from an RSS/Atom feed, plain text, or JSONL on stdin. Each
+// article is scored against the model, and those at or above the threshold
+// are written out. Processing happens in batches (default 50) so that output
+// streams continuously.
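+//
+// A typical pipeline, as a sketch (file names are illustrative):
+//
+//	cat articles.jsonl | scholscan scan --from-articles --model model.json --threshold 0.8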
+package cmds
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/mmcdole/gofeed"
+ "scholscan/core"
+)
+
+
+// ============================================================================
+// ┏━╸┏━┓┏┳┓┏┳┓┏━┓┏┓╻╺┳┓
+// ┃ ┃ ┃┃┃┃┃┃┃┣━┫┃┗┫ ┃┃
+// ┗━╸┗━┛╹ ╹╹ ╹╹ ╹╹ ╹╺┻┛
+// ============================================================================
+
+
+// ScanCommand scores articles with a trained model and outputs those whose
+// score is at or above the threshold.
+type ScanCommand struct {
+ URL string
+ FromText bool
+ FromArticles bool
+
+ ModelPath string
+ Threshold string
+
+ MinTitleLength int
+ ChunkSize int
+
+ EventsOut string
+ MetricsOut string
+ Verbose bool
+}
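+
+// Driving the command programmatically, as a minimal sketch (the flag values
+// and the model path are illustrative, not fixed by the package):
+//
+//	cmd := &ScanCommand{}
+//	if err := cmd.Init([]string{"--from-articles", "--model", "model.json"}); err != nil {
+//		log.Fatal(err)
+//	}
+//	if err := cmd.Run(os.Stdin, os.Stdout); err != nil {
+//		log.Fatal(err)
+//	}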
+
+func (c *ScanCommand) Name() string { return "scan" }
+
+func (c *ScanCommand) Init(args []string) error {
+ fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
+ fs.Usage = func() {
+ fmt.Fprint(fs.Output(), `Usage: scholscan scan [options]
+
+Fetches articles, scores them with the trained model, and outputs those whose
+score is at or above the threshold.
+
+Source options (exactly one required):
+ --url <feed_url> Fetch articles from RSS/Atom feed
+ --from-text Extract URLs from text on stdin
+ --from-articles Use Article JSONL from stdin directly
+
+Model and filtering:
+ --model <path> Path to trained model JSON file (required)
+ --threshold <float> Score threshold (if not provided, uses model's recommended threshold)
+
+Options:
+`)
+ fs.PrintDefaults()
+ fmt.Fprint(fs.Output(), `
+Examples:
+ scholscan scan --url "http://some.blog/rss.xml" --model model.json > interesting.jsonl
+ echo "see https://example.com" | scholscan scan --from-text --model model.json
+ cat articles.jsonl | scholscan scan --from-articles --model model.json
+`)
+ }
+
+ fs.StringVar(&c.URL, "url", "", "RSS/Atom feed URL to fetch")
+ fs.BoolVar(&c.FromText, "from-text", false, "Extract URLs from text on stdin")
+ fs.BoolVar(&c.FromArticles, "from-articles", false, "Use Article JSONL from stdin")
+ fs.StringVar(&c.ModelPath, "model", "", "Path to trained model JSON file (required)")
+ fs.StringVar(&c.Threshold, "threshold", "", "Score threshold for filtering (if not provided, uses model's recommended threshold)")
+ fs.IntVar(&c.MinTitleLength, "min-title-length", core.MinTitleLength, "Minimum title length to consider valid")
+ fs.IntVar(&c.ChunkSize, "chunk-size", core.DefaultChunkSize, "Number of articles to process in each batch")
+ fs.StringVar(&c.EventsOut, "events-out", "events.jsonl", "Write per-article events to a JSONL file")
+ fs.StringVar(&c.MetricsOut, "metrics-out", "metrics.json", "Write summary metrics to a JSON file")
+ fs.BoolVar(&c.Verbose, "verbose", false, "Show progress information")
+
+ if err := fs.Parse(args); err != nil {
+ return err
+ }
+
+ if fs.NArg() != 0 {
+ return fmt.Errorf("unexpected arguments provided: %v", fs.Args())
+ }
+
+	// exactly one source option must be specified
+ sourceCount := 0
+ if c.URL != "" {
+ sourceCount++
+ }
+ if c.FromText {
+ sourceCount++
+ }
+ if c.FromArticles {
+ sourceCount++
+ }
+
+ if sourceCount == 0 {
+ return fmt.Errorf("exactly one source option must be specified: --url, --from-text, or --from-articles")
+ }
+ if sourceCount > 1 {
+ return fmt.Errorf("only one source option may be specified: --url, --from-text, or --from-articles")
+ }
+
+ if c.ModelPath == "" {
+ return fmt.Errorf("--model flag is required")
+ }
+
+ // prevent dir traversal
+ if strings.Contains(filepath.Clean(c.ModelPath), "..") {
+ return fmt.Errorf("invalid model path: directory traversal not allowed")
+ }
+
+ if c.URL != "" {
+ if _, err := url.Parse(c.URL); err != nil {
+ return fmt.Errorf("invalid URL format: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// Run runs the scan: load the model, decide on a threshold, get articles, then score them in chunks.
+// Configuration and fetch problems abort the run; per-article scoring errors are logged and skipped.
+func (c *ScanCommand) Run(stdin io.Reader, stdout io.Writer) error {
+ if c.Verbose {
+ log.SetOutput(os.Stderr)
+ log.Println("Starting scan workflow...")
+ log.Printf("Source: %v", c.getSourceDescription())
+ log.Printf("Model: %s", c.ModelPath)
+ }
+
+ model, err := c.loadModel()
+ if err != nil {
+ return fmt.Errorf("failed to load model: %w", err)
+ }
+
+ threshold, err := c.getThreshold(model)
+ if err != nil {
+ return fmt.Errorf("failed to determine threshold: %w", err)
+ }
+
+ if c.Verbose {
+ log.Printf("Using threshold: %.3f", threshold)
+ }
+
+ var articles []*core.Article
+ if c.FromArticles {
+ articles, err = c.readArticlesFromStdin(stdin)
+ } else {
+		articles, err = c.fetchArticles(stdin)
+ }
+ if err != nil {
+ return fmt.Errorf("failed to get articles: %w", err)
+ }
+
+ if c.Verbose {
+ log.Printf("Processing %d articles", len(articles))
+ }
+
+ // process articles in chunks
+	return c.processArticles(articles, model, threshold, stdout)
+}
+
+
+// ============================================================================
+// ┏┳┓┏━┓╺┳┓┏━╸╻ ┏┓ ┏━╸┏━┓┏┓╻┏━╸╻┏━╸
+// ┃┃┃┃ ┃ ┃┃┣╸ ┃ ┃╺╋╸ ┃ ┃ ┃┃┗┫┣╸ ┃┃╺┓
+// ╹ ╹┗━┛╺┻┛┗━╸┗━╸ ┗━┛ ┗━╸┗━┛╹ ╹╹ ╹┗━┛
+// ============================================================================
+
+
+
+func (c *ScanCommand) getSourceDescription() string {
+ if c.URL != "" {
+ return fmt.Sprintf("RSS feed: %s", c.URL)
+ }
+ if c.FromText {
+ return "text from stdin"
+ }
+ if c.FromArticles {
+ return "articles from stdin"
+ }
+ return "unknown"
+}
+
+// loadModel reads and parses the model JSON file.
+// The envelope contains weights, vocabulary, and optionally a recommended threshold.
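+//
+// A model file might look roughly like this; the exact field names depend on
+// core.ModelEnvelope's JSON tags and are an assumption here:
+//
+//	{"weights": [0.12, -0.3], "meta": {"recommended_threshold": 0.75}}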
+func (c *ScanCommand) loadModel() (*core.ModelEnvelope, error) {
+ f, err := os.Open(c.ModelPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open model file %s: %w", c.ModelPath, err)
+ }
+ defer f.Close()
+
+ var model core.ModelEnvelope
+ if err := json.NewDecoder(f).Decode(&model); err != nil {
+ return nil, fmt.Errorf("failed to decode model: %w", err)
+ }
+
+ return &model, nil
+}
+
+// getThreshold resolves the score threshold: an explicit --threshold flag
+// wins, then the model's recommended_threshold metadata, then the package
+// default. A malformed flag value is an error rather than a silent fallback.
+func (c *ScanCommand) getThreshold(model *core.ModelEnvelope) (float64, error) {
+	if c.Threshold != "" {
+		threshold, err := strconv.ParseFloat(c.Threshold, 64)
+		if err != nil {
+			return 0, fmt.Errorf("invalid threshold %q: %w", c.Threshold, err)
+		}
+		return threshold, nil
+	}
+
+	if model.Meta != nil {
+		if rec, ok := model.Meta["recommended_threshold"].(float64); ok {
+			return rec, nil
+		}
+	}
+
+	return core.DefaultScoreThreshold, nil
+}
+
+// ============================================================================
+// ┏━┓┏━┓╺┳╸╻┏━╸╻ ┏━╸ ┏━┓┏━┓┏━╸┏━┓
+// ┣━┫┣┳┛ ┃ ┃┃ ┃ ┣╸ ┗━┓┣┳┛┃ ┗━┓
+// ╹ ╹╹┗╸ ╹ ╹┗━╸┗━╸┗━╸ ┗━┛╹┗╸┗━╸┗━┛
+// ============================================================================
+
+
+func (c *ScanCommand) fetchArticles(stdin io.Reader) ([]*core.Article, error) {
+	if c.FromText {
+		return c.extractURLsFromText(stdin)
+ }
+ if c.URL != "" {
+ return c.fetchRSSFeed(c.URL)
+ }
+ return nil, fmt.Errorf("no valid source specified")
+}
+
+// extractURLsFromText pulls URLs from plain text on stdin.
+// Each URL becomes a minimal Article whose synthesized title embeds the URL,
+// since the title is what downstream scoring consumes.
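+//
+// For example, the input line "see https://example.com/post for details"
+// yields one Article with URL "https://example.com/post" and the synthesized
+// title "Article from https://example.com/post".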
+func (c *ScanCommand) extractURLsFromText(stdin io.Reader) ([]*core.Article, error) {
+ var urls []string
+ s := bufio.NewScanner(stdin)
+ for s.Scan() {
+ line := s.Text()
+		// collect whitespace-delimited tokens that carry an http(s) scheme
+ fields := strings.Fields(line)
+ for _, field := range fields {
+ if strings.HasPrefix(field, "http://") || strings.HasPrefix(field, "https://") {
+ urls = append(urls, field)
+ }
+ }
+ }
+
+	// create Article objects for the extracted URLs
+	articles := make([]*core.Article, len(urls))
+	for i, u := range urls {
+		articles[i] = &core.Article{
+			URL:     u,
+			Title:   fmt.Sprintf("Article from %s", u),
+			Content: "",
+		}
+	}
+
+ return articles, s.Err()
+}
+
+// fetchRSSFeed fetches and parses a single RSS feed, bounded by core.DefaultHTTPTimeout.
+// We skip articles with short titles since they're usually noise or truncated.
+func (c *ScanCommand) fetchRSSFeed(url string) ([]*core.Article, error) {
+ client := &http.Client{Timeout: core.DefaultHTTPTimeout}
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("error building request: %w", err)
+ }
+ req.Header.Set("User-Agent", core.PoliteUserAgent)
+
+ ctx, cancel := context.WithTimeout(context.Background(), core.DefaultHTTPTimeout)
+ defer cancel()
+
+ resp, err := client.Do(req.WithContext(ctx))
+ if err != nil {
+ return nil, fmt.Errorf("error fetching %s: %w", url, err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
+ }
+
+	// parse the feed directly from the response body
+	fp := gofeed.NewParser()
+	feed, err := fp.Parse(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing feed from %s: %w", url, err)
+	}
+
+ var articles []*core.Article
+ for _, item := range feed.Items {
+ article := &core.Article{
+ URL: item.Link,
+ Title: strings.TrimSpace(item.Title),
+ }
+
+ if len(article.Title) >= c.MinTitleLength {
+ articles = append(articles, article)
+ }
+ }
+
+ return articles, nil
+}
+
+// readArticlesFromStdin reads Article objects from JSONL on stdin.
+// Malformed lines are skipped to allow partial processing of corrupted input.
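+//
+// Each input line is one JSON object; a line might look like this (the field
+// names assume core.Article's JSON tags mirror its Go fields):
+//
+//	{"url": "https://example.com/post", "title": "A sufficiently long title"}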
+func (c *ScanCommand) readArticlesFromStdin(stdin io.Reader) ([]*core.Article, error) {
+	var articles []*core.Article
+	s := bufio.NewScanner(stdin)
+	for s.Scan() {
+		line := strings.TrimSpace(s.Text())
+		if line == "" {
+			continue
+		}
+
+		var article core.Article
+		if err := json.Unmarshal([]byte(line), &article); err != nil {
+			// skip just this malformed line; a streaming json.Decoder would
+			// stay wedged on bad input and loop forever
+			continue
+		}
+
+		if len(article.Title) >= c.MinTitleLength {
+			articles = append(articles, &article)
+		}
+	}
+	return articles, s.Err()
+}
+
+
+
+// ============================================================================
+// ┏━┓┏━┓┏━┓┏━╸┏━╸┏━┓┏━┓ ┏━┓┏━┓╺┳╸╻┏━╸╻ ┏━╸┏━┓
+// ┣━┛┣┳┛┃ ┃┃ ┣╸ ┗━┓┗━┓ ┣━┫┣┳┛ ┃ ┃┃ ┃ ┣╸ ┗━┓
+// ╹ ╹┗╸┗━┛┗━╸┗━╸┗━┛┗━┛ ╹ ╹╹┗╸ ╹ ╹┗━╸┗━╸┗━╸┗━┛
+// ============================================================================
+
+
+// processArticles handles scoring and filtering in batches to keep memory usage predictable.
+// Scoring errors don't crash the process; we log them and continue with the next article.
+func (c *ScanCommand) processArticles(articles []*core.Article, model *core.ModelEnvelope, threshold float64, stdout io.Writer) error {
+ vectorizer := core.CreateVectorizerFromModel(model)
+
+ encoder := json.NewEncoder(stdout)
+
+ // process each batch
+ for i := 0; i < len(articles); i += c.ChunkSize {
+ end := i + c.ChunkSize
+ if end > len(articles) {
+ end = len(articles)
+ }
+
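+		// e.g. with 120 articles and the default chunk size of 50, the chunks
+		// span indices [0:50), [50:100), and [100:120)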
+ chunk := articles[i:end]
+ if c.Verbose {
+ log.Printf("Processing chunk %d-%d of %d articles", i+1, end, len(articles))
+ }
+
+		// score the batch: vectorize trimmed titles, then apply the model weights
+ docs := make([]string, len(chunk))
+ for j, article := range chunk {
+ docs[j] = strings.TrimSpace(article.Title)
+ }
+
+ vectors := vectorizer.Transform(docs)
+ scores := make([]float64, len(chunk))
+
+ for j, vector := range vectors {
+ score, err := core.PredictScore(vector, model.Weights)
+ if err != nil {
+ log.Printf("Error computing score for article %d: %v", i+j, err)
+ scores[j] = 0.0
+ } else {
+ scores[j] = score
+ }
+ }
+
+ for j, article := range chunk {
+ score := scores[j]
+ article.Score = &score
+
+ if score >= threshold {
+ if err := encoder.Encode(article); err != nil {
+ log.Printf("Error encoding article: %v", err)
+ }
+ }
+ }
+ }
+
+ if c.Verbose {
+ log.Println("Scan complete")
+ }
+
+ return nil
+}