diff --git a/src/macro_utils.jl b/src/macro_utils.jl index 7afc3e4..8cf16d6 100644 --- a/src/macro_utils.jl +++ b/src/macro_utils.jl @@ -5,8 +5,6 @@ function materialize_args(expr::Expr) return (expr.args[2], expr.args[3]) end -function fused(expr) end - macro fused(expr) _pairs = gensym() quote @@ -41,15 +39,50 @@ end fused_pairs(expr::Expr) = Meta.parse(_fused_pairs(expr)) -macro fused_multibroadcast(expr) - esc(fused_multibroadcast("MultiBroadcastFusion.FusedMultiBroadcast", expr)) +""" + @make_fused type_name fused_named + +This macro + - Imports MultiBroadcastFusion + - Defines a type, `type_name` + - Defines a macro, `@fused_name` + +This allows users to flexibility +to customize their broadcast fusion. + +# Example +```julia +import MultiBroadcastFusion as MBF +MBF.@make_fused MyFusedBroadcast my_fused + +Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!") + +x1 = rand(3,3) +y1 = rand(3,3) +y2 = rand(3,3) + +# 4 reads, 2 writes +@my_fused begin + @. y1 = x1 + @. y2 = x1 end +``` +""" +macro make_fused(type_name, fused_name) + macrodef = quote + import MultiBroadcastFusion + macro $fused_name(expr) + Base.copyto!($type_name(MultiBroadcastFusion.fused_pairs(expr))) + end + end -macro fused_multibroadcast(fmb, expr) - esc(fused_multibroadcast(fmb, expr)) + esc(quote + struct $type_name{T} + pairs::T + end + $macrodef + end) end -fused_multibroadcast(fmb, expr::Expr) = - Meta.parse("$(fmb)($(_fused_pairs(expr)))") function build_expr(s::String, code_remain) n_subs = count("%", s) diff --git a/test/expr_fused_pairs.jl b/test/expr_fused_pairs.jl index 5d3d63d..06c73d5 100644 --- a/test/expr_fused_pairs.jl +++ b/test/expr_fused_pairs.jl @@ -15,16 +15,3 @@ import MultiBroadcastFusion as MBF )) @test MBF.fused_pairs(expr_in) == expr_out end - -@testset "fused_multibroadcast" begin - expr_in = quote - @. y1 = x1 + x2 + x3 + x4 - @. y2 = x2 + x3 + x4 + x5 - end - - expr_out = :(MultiBroadcastFusion.FusedMultiBroadcast(( - Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)), - Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)), - ))) - @test MBF.fused_multibroadcast(MBF.FusedMultiBroadcast, expr_in) == expr_out -end