diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index 90e648fe2c..8096a868af 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -8,7 +8,6 @@ import ( "github.com/metacubex/mihomo/common/nnip" "github.com/metacubex/mihomo/component/profile/cachefile" - "github.com/metacubex/mihomo/component/trie" C "github.com/metacubex/mihomo/constant" ) @@ -36,8 +35,7 @@ type Pool struct { offset netip.Addr cycle bool mux sync.Mutex - host *trie.DomainTrie[struct{}] - rules []C.Rule + host []C.Rule ipnet netip.Prefix store store } @@ -68,14 +66,8 @@ func (p *Pool) LookBack(ip netip.Addr) (string, bool) { // ShouldSkipped return if domain should be skipped func (p *Pool) ShouldSkipped(domain string) bool { - if p.host != nil { - if p.host.Search(domain) != nil { - return true - } - } - for _, rule := range p.rules { - metadata := &C.Metadata{Host: domain} - if match, _ := rule.Match(metadata); match { + for _, rule := range p.host { + if match, _ := rule.Match(&C.Metadata{Host: domain}); match { return true } } @@ -164,9 +156,7 @@ func (p *Pool) restoreState() { type Options struct { IPNet netip.Prefix - Host *trie.DomainTrie[struct{}] - - Rules []C.Rule + Host []C.Rule // Size sets the maximum number of entries in memory // and does not work if Persistence is true @@ -197,7 +187,6 @@ func New(options Options) (*Pool, error) { offset: first.Prev(), cycle: false, host: options.Host, - rules: options.Rules, ipnet: options.IPNet, } if options.Persistence { diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index cc50fcf7b5..4d4f8c1931 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -9,6 +9,8 @@ import ( "github.com/metacubex/mihomo/component/profile/cachefile" "github.com/metacubex/mihomo/component/trie" + C "github.com/metacubex/mihomo/constant" + RC "github.com/metacubex/mihomo/rules/common" "github.com/sagernet/bbolt" "github.com/stretchr/testify/assert" @@ -154,7 +156,7 @@ func TestPool_Skip(t *testing.T) { pools, tempfile, err := createPools(Options{ IPNet: ipnet, Size: 10, - Host: tree, + Host: []C.Rule{RC.NewDomainSet(tree.NewDomainSet(), "")}, }) assert.Nil(t, err) defer os.Remove(tempfile) diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 4438638dad..c96f5a4b03 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -9,7 +9,6 @@ import ( "github.com/metacubex/mihomo/common/lru" N "github.com/metacubex/mihomo/common/net" - "github.com/metacubex/mihomo/component/trie" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/constant/sniffer" "github.com/metacubex/mihomo/log" @@ -26,17 +25,26 @@ var Dispatcher *SnifferDispatcher type SnifferDispatcher struct { enable bool sniffers map[sniffer.Sniffer]SnifferConfig - forceDomain *trie.DomainSet - skipSNI *trie.DomainSet + forceDomain []C.Rule + skipDomain []C.Rule skipList *lru.LruCache[string, uint8] forceDnsMapping bool parsePureIp bool } func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool { - return (metadata.Host == "" && sd.parsePureIp) || - sd.forceDomain.Has(metadata.Host) || - (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) + if metadata.Host == "" && sd.parsePureIp { + return true + } + if metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping { + return true + } + for _, rule := range sd.forceDomain { + if ok, _ := rule.Match(&C.Metadata{Host: metadata.Host}); ok { + return true + } + } + return false } func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool { @@ -94,9 +102,11 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) return false } else { - if sd.skipSNI.Has(host) { - log.Debugln("[Sniffer] Skip sni[%s]", host) - return false + for _, rule := range sd.skipDomain { + if ok, _ := rule.Match(&C.Metadata{Host: host}); ok { + log.Debugln("[Sniffer] Skip sni[%s]", host) + return false + } } sd.skipList.Delete(dst) @@ -187,12 +197,12 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) { } func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, - forceDomain *trie.DomainSet, skipSNI *trie.DomainSet, + forceDomain []C.Rule, skipDomain []C.Rule, forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) { dispatcher := SnifferDispatcher{ enable: true, forceDomain: forceDomain, - skipSNI: skipSNI, + skipDomain: skipDomain, skipList: lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)), forceDnsMapping: forceDnsMapping, parsePureIp: parsePureIp, diff --git a/component/trie/domain.go b/component/trie/domain.go index db30402ede..6d3e37f70a 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -134,6 +134,13 @@ func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) { } } +func (t *DomainTrie[T]) IsEmpty() bool { + if t == nil { + return true + } + return t.root.isEmpty() +} + func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool { for key, data := range node.getChildren() { newItems := append([]string{key}, items...) diff --git a/config/config.go b/config/config.go index fae88e1a04..2685013717 100644 --- a/config/config.go +++ b/config/config.go @@ -164,8 +164,8 @@ type IPTables struct { type Sniffer struct { Enable bool Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig - ForceDomain *trie.DomainSet - SkipDomain *trie.DomainSet + ForceDomain []C.Rule + SkipDomain []C.Rule ForceDnsMapping bool ParsePureIp bool } @@ -627,7 +627,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } } - config.Sniffer, err = parseSniffer(rawCfg.Sniffer) + config.Sniffer, err = parseSniffer(rawCfg.Sniffer, rules, ruleProviders) if err != nil { return nil, err } @@ -1408,87 +1408,27 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul return nil, err } - var host *trie.DomainTrie[struct{}] - var fakeIPRules []C.Rule - // fake ip skip host filter - if len(cfg.FakeIPFilter) != 0 { - host = trie.New[struct{}]() - for _, domain := range cfg.FakeIPFilter { - if strings.Contains(strings.ToLower(domain), ",") { - if strings.Contains(domain, "geosite:") { - subkeys := strings.Split(domain, ":") - subkeys = subkeys[1:] - subkeys = strings.Split(subkeys[0], ",") - for _, country := range subkeys { - found := false - for _, rule := range rules { - if rule.RuleType() == C.GEOSITE { - if strings.EqualFold(country, rule.Payload()) { - found = true - fakeIPRules = append(fakeIPRules, rule) - } - } - } - if !found { - rule, err := RC.NewGEOSITE(country, "") - if err != nil { - return nil, err - } - fakeIPRules = append(fakeIPRules, rule) - } - } - - } - } else if strings.Contains(strings.ToLower(domain), "rule-set:") { - subkeys := strings.Split(domain, ":") - subkeys = subkeys[1:] - subkeys = strings.Split(subkeys[0], ",") - for _, domainSetName := range subkeys { - if rp, ok := ruleProviders[domainSetName]; !ok { - return nil, fmt.Errorf("not found rule-set: %s", domainSetName) - } else { - switch rp.Behavior() { - case providerTypes.IPCIDR: - return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior()) - case providerTypes.Classical: - log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior()) - default: - } - } - rule, err := RP.NewRuleSet(domainSetName, "", true) - if err != nil { - return nil, err - } - - fakeIPRules = append(fakeIPRules, rule) - } - } else { - _ = host.Insert(domain, struct{}{}) - } - } - } - + var fakeIPTrie *trie.DomainTrie[struct{}] if len(dnsCfg.Fallback) != 0 { - if host == nil { - host = trie.New[struct{}]() - } + fakeIPTrie = trie.New[struct{}]() for _, fb := range dnsCfg.Fallback { if net.ParseIP(fb.Addr) != nil { continue } - _ = host.Insert(fb.Addr, struct{}{}) + _ = fakeIPTrie.Insert(fb.Addr, struct{}{}) } } - if host != nil { - host.Optimize() + // fake ip skip host filter + host, err := parseDomain(cfg.FakeIPFilter, fakeIPTrie, rules, ruleProviders) + if err != nil { + return nil, err } pool, err := fakeip.New(fakeip.Options{ IPNet: fakeIPRange, Size: 1000, Host: host, - Rules: fakeIPRules, Persistence: rawCfg.Profile.StoreFakeIP, }) if err != nil { @@ -1609,7 +1549,7 @@ func parseTuicServer(rawTuic RawTuicServer, general *General) error { return nil } -func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { +func parseSniffer(snifferRaw RawSniffer, rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider) (*Sniffer, error) { sniffer := &Sniffer{ Enable: snifferRaw.Enable, ForceDnsMapping: snifferRaw.ForceDnsMapping, @@ -1672,23 +1612,83 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { sniffer.Sniffers = loadSniffer - forceDomainTrie := trie.New[struct{}]() - for _, domain := range snifferRaw.ForceDomain { - err := forceDomainTrie.Insert(domain, struct{}{}) - if err != nil { - return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) - } + forceDomain, err := parseDomain(snifferRaw.ForceDomain, nil, rules, ruleProviders) + if err != nil { + return nil, fmt.Errorf("error in force-domain, error:%w", err) } - sniffer.ForceDomain = forceDomainTrie.NewDomainSet() + sniffer.ForceDomain = forceDomain - skipDomainTrie := trie.New[struct{}]() - for _, domain := range snifferRaw.SkipDomain { - err := skipDomainTrie.Insert(domain, struct{}{}) - if err != nil { - return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err) - } + skipDomain, err := parseDomain(snifferRaw.SkipDomain, nil, rules, ruleProviders) + if err != nil { + return nil, fmt.Errorf("error in skip-domain, error:%w", err) } - sniffer.SkipDomain = skipDomainTrie.NewDomainSet() + sniffer.SkipDomain = skipDomain return sniffer, nil } + +func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider) (domainRules []C.Rule, err error) { + var rule C.Rule + for _, domain := range domains { + domainLower := strings.ToLower(domain) + if strings.Contains(domainLower, "geosite:") { + subkeys := strings.Split(domain, ":") + subkeys = subkeys[1:] + subkeys = strings.Split(subkeys[0], ",") + for _, country := range subkeys { + found := false + for _, rule = range rules { + if rule.RuleType() == C.GEOSITE { + if strings.EqualFold(country, rule.Payload()) { + found = true + domainRules = append(domainRules, rule) + } + } + } + if !found { + rule, err = RC.NewGEOSITE(country, "") + if err != nil { + return nil, err + } + domainRules = append(domainRules, rule) + } + } + } else if strings.Contains(domainLower, "rule-set:") { + subkeys := strings.Split(domain, ":") + subkeys = subkeys[1:] + subkeys = strings.Split(subkeys[0], ",") + for _, domainSetName := range subkeys { + if rp, ok := ruleProviders[domainSetName]; !ok { + return nil, fmt.Errorf("not found rule-set: %s", domainSetName) + } else { + switch rp.Behavior() { + case providerTypes.IPCIDR: + return nil, fmt.Errorf("rule provider type error, except domain,actual %s", rp.Behavior()) + case providerTypes.Classical: + log.Warnln("%s provider is %s, only matching it contain domain rule", rp.Name(), rp.Behavior()) + default: + } + } + rule, err = RP.NewRuleSet(domainSetName, "", true) + if err != nil { + return nil, err + } + + domainRules = append(domainRules, rule) + } + } else { + if domainTrie == nil { + domainTrie = trie.New[struct{}]() + } + err = domainTrie.Insert(domain, struct{}{}) + if err != nil { + return nil, err + } + } + } + if !domainTrie.IsEmpty() { + rule = RC.NewDomainSet(domainTrie.NewDomainSet(), "") + domainRules = append(domainRules, rule) + } + return +} diff --git a/rules/common/domain_set.go b/rules/common/domain_set.go new file mode 100644 index 0000000000..95f5896aef --- /dev/null +++ b/rules/common/domain_set.go @@ -0,0 +1,41 @@ +package common + +import ( + "github.com/metacubex/mihomo/component/trie" + C "github.com/metacubex/mihomo/constant" +) + +type DomainSet struct { + *Base + domainSet *trie.DomainSet + adapter string +} + +func (d *DomainSet) RuleType() C.RuleType { + return C.Domain +} + +func (d *DomainSet) Match(metadata *C.Metadata) (bool, string) { + if d.domainSet == nil { + return false, "" + } + return d.domainSet.Has(metadata.RuleHost()), d.adapter +} + +func (d *DomainSet) Adapter() string { + return d.adapter +} + +func (d *DomainSet) Payload() string { + return "" +} + +func NewDomainSet(domainSet *trie.DomainSet, adapter string) *DomainSet { + return &DomainSet{ + Base: &Base{}, + domainSet: domainSet, + adapter: adapter, + } +} + +var _ C.Rule = (*DomainSet)(nil)