Skip to content

Commit

Permalink
Define macro helper
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 8, 2024
1 parent 77dbc0b commit 4bba020
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
49 changes: 41 additions & 8 deletions src/macro_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ function materialize_args(expr::Expr)
return (expr.args[2], expr.args[3])
end

function fused(expr) end

macro fused(expr)
_pairs = gensym()
quote
Expand Down Expand Up @@ -41,15 +39,50 @@ end

fused_pairs(expr::Expr) = Meta.parse(_fused_pairs(expr))

macro fused_multibroadcast(expr)
esc(fused_multibroadcast("MultiBroadcastFusion.FusedMultiBroadcast", expr))
"""
@make_fused type_name fused_named
This macro
- Imports MultiBroadcastFusion
- Defines a type, `type_name`
- Defines a macro, `@fused_name`
This allows users to flexibility
to customize their broadcast fusion.
# Example
```julia
import MultiBroadcastFusion as MBF
MBF.@make_fused MyFusedBroadcast my_fused
Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# 4 reads, 2 writes
@my_fused begin
@. y1 = x1
@. y2 = x1
end
```
"""
macro make_fused(type_name, fused_name)
macrodef = quote
import MultiBroadcastFusion
macro $fused_name(expr)
Base.copyto!($type_name(MultiBroadcastFusion.fused_pairs(expr)))
end
end

macro fused_multibroadcast(fmb, expr)
esc(fused_multibroadcast(fmb, expr))
esc(quote
struct $type_name{T}
pairs::T
end
$macrodef
end)
end
fused_multibroadcast(fmb, expr::Expr) =
Meta.parse("$(fmb)($(_fused_pairs(expr)))")

function build_expr(s::String, code_remain)
n_subs = count("%", s)
Expand Down
13 changes: 0 additions & 13 deletions test/expr_fused_pairs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,3 @@ import MultiBroadcastFusion as MBF
))
@test MBF.fused_pairs(expr_in) == expr_out
end

@testset "fused_multibroadcast" begin
expr_in = quote
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
end

expr_out = :(MultiBroadcastFusion.FusedMultiBroadcast((
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
)))
@test MBF.fused_multibroadcast(MBF.FusedMultiBroadcast, expr_in) == expr_out
end

0 comments on commit 4bba020

Please sign in to comment.