diff --git a/daemon/main.go b/daemon/main.go index e238e7b667..0d757e75d2 100644 --- a/daemon/main.go +++ b/daemon/main.go @@ -34,6 +34,8 @@ import ( "runtime" "runtime/pprof" "runtime/trace" + "strconv" + "strings" "syscall" "time" @@ -68,6 +70,7 @@ var ( ebpfModPath = "" // /usr/lib/opensnitchd/ebpf noLiveReload = false queueNum = 0 + queues = "" repeatQueueNum int //will be set later to queueNum + 1 workers = 16 debug = false @@ -91,7 +94,7 @@ var ( queue = (*netfilter.Queue)(nil) repeatQueue = (*netfilter.Queue)(nil) repeatPktChan = (<-chan netfilter.Packet)(nil) - pktChan = (<-chan netfilter.Packet)(nil) + pktChan = [](<-chan netfilter.Packet)(nil) wrkChan = (chan netfilter.Packet)(nil) sigChan = (chan os.Signal)(nil) exitChan = (chan bool)(nil) @@ -106,6 +109,7 @@ func init() { flag.StringVar(&procmonMethod, "process-monitor-method", procmonMethod, "How to search for processes path. Options: ftrace, audit (experimental), ebpf (experimental), proc (default)") flag.StringVar(&uiSocket, "ui-socket", uiSocket, "Path the UI gRPC service listener (https://github.com/grpc/grpc/blob/master/doc/naming.md).") flag.IntVar(&queueNum, "queue-num", queueNum, "Netfilter queue number.") + flag.StringVar(&queues, "queues", queues, "Netfilter total queues. Format: -queues 1:10 (starts 10 queues)") flag.IntVar(&workers, "workers", workers, "Number of concurrent workers.") flag.BoolVar(&noLiveReload, "no-live-reload", debug, "Disable rules live reloading.") @@ -154,16 +158,10 @@ func overwriteLogging() bool { func setupQueues() { // prepare the queue var err error - queue, err = netfilter.NewQueue(uint16(queueNum)) - if err != nil { - msg := fmt.Sprintf("Error creating queue #%d: %s", queueNum, err) - uiClient.SendWarningAlert(msg) - log.Warning("Is opensnitchd already running?") - log.Fatal(msg) - } - pktChan = queue.Packets() - repeatQueueNum = queueNum + 1 + // use upper range numbers for the repeating queue, not to interfere with + // the queue ranges. + repeatQueueNum = 32000 - queueNum repeatQueue, err = netfilter.NewQueue(uint16(repeatQueueNum)) if err != nil { @@ -173,6 +171,25 @@ func setupQueues() { log.Warning(msg) } repeatPktChan = repeatQueue.Packets() + + // the format to specify multiple queues is 1:10 + qs := strings.SplitN(queues, ":", 2) + lowb := uint64(0) + upb := uint64(1) + if len(qs) > 1 { + lowb, err = strconv.ParseUint(qs[0], 10, 16) + if err != nil { + lowb = 0 + } + upb, err = strconv.ParseUint(qs[1], 10, 16) + if err != nil { + upb = lowb + 1 + } + } + for i := lowb; i < upb; i++ { + q, _ := netfilter.NewQueue(uint16(i)) + pktChan = append(pktChan, q.Packets()) + } } func setupLogging() { @@ -258,12 +275,9 @@ func worker(id int) { case <-ctx.Done(): goto Exit default: - pkt, ok := <-wrkChan - if !ok { - log.Debug("worker channel closed %d", id) - goto Exit - } + pkt := <-wrkChan onPacket(pkt) + } } Exit: @@ -273,7 +287,7 @@ Exit: func setupWorkers() { log.Debug("Starting %d workers ...", workers) // setup the workers - wrkChan = make(chan netfilter.Packet) + wrkChan = make(chan netfilter.Packet, workers) for i := 0; i < workers; i++ { go worker(i) } @@ -638,18 +652,26 @@ func main() { initSystemdResolvedMonitor() log.Info("Running on netfilter queue #%d ...", queueNum) - for { - select { - case <-ctx.Done(): - goto Exit - case pkt, ok := <-pktChan: - if !ok { - goto Exit + for _, p := range pktChan { + go func(c <-chan netfilter.Packet) { + for { + select { + case <-ctx.Done(): + return + case pkt, ok := <-c: + if !ok { + return + } + wrkChan <- pkt + } } - wrkChan <- pkt - } + }(p) } -Exit: + select { + case <-sigChan: + case <-ctx.Done(): + } + close(wrkChan) doCleanup(queue, repeatQueue) os.Exit(0)