Skip to content

Commit

Permalink
Merge pull request #13 from CliMA/ck/fused_pairs_flexible
Browse files Browse the repository at this point in the history
Add support for fused pairs across barriers
  • Loading branch information
charleskawczynski authored Apr 11, 2024
2 parents 66dd057 + 5da9e66 commit 3f3b3d9
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MultiBroadcastFusion"
uuid = "c3c07f87-98de-43f2-a76f-835b330b2cbb"
authors = ["CliMA Contributors <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[compat]
julia = "^1.10"
Expand Down
51 changes: 50 additions & 1 deletion src/MultiBroadcastFusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, i...) =
rcopyto_at!(first(pairs), i...)
@inline rcopyto_at!(pairs::Tuple{}, i...) = nothing

include("macro_utils.jl")
include("utils.jl")
include("code_lowered_single_expression.jl")
include("fused_pairs.jl")
include("fused_pairs_flexible.jl")

"""
@make_fused type_name fused_named
Expand Down Expand Up @@ -61,4 +64,50 @@ macro make_fused(type_name, fused_name)
end
end

"""
@make_fused_flexible type_name fused_named
This macro
- Imports MultiBroadcastFusion
- Defines a type, `type_name`
- Defines a macro, `@fused_name`
This allows users to flexibility
to customize their broadcast fusion.
# Example
```julia
import MultiBroadcastFusion as MBF
MBF.@make_fused_flexible MyFusedBroadcast my_fused
Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# 4 reads, 2 writes
@my_fused begin
@. y1 = x1
@. y2 = x1
end
```
"""
macro make_fused_flexible(type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
struct $t{T <: Tuple}
pairs::T
end
macro $f(expr)
_pairs = esc($(fused_pairs_flexible)(expr, gensym()))
t = $t
quote
Base.copyto!($t($_pairs))
end
end
end
end

end # module MultiBroadcastFusion
15 changes: 15 additions & 0 deletions src/code_lowered_single_expression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# General case: do nothing (identity)
substitute(x, code) = x
substitute(x::Core.SSAValue, code) = substitute(code[x.id], code)
substitute(x::Core.ReturnNode, code) = substitute(code[x.val.id], code)
substitute(s::Symbol, code) = s
# Expression: recursively substitute for Expr
substitute(e::Expr, code) =
Expr(substitute(e.head, code), substitute.(e.args, Ref(code))...)

code_info(expr) = Base.Meta.lower(Main, expr).args[1]
function code_lowered_single_expression(expr)
code = code_info(expr).code # vector
s = string(substitute(code[end], code))
return Base.Meta.parse(s)
end
59 changes: 59 additions & 0 deletions src/fused_pairs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#####
##### Simple version
#####

# General case: do nothing (identity)
transform(x) = x
transform(x::Core.SSAValue) = transform(code[x.id])
transform(x::Core.ReturnNode) = transform(code[x.val.id])
transform(s::Symbol) = s
# Expression: recursively transform for Expr
function transform(e::Expr)
if e.head == :macrocall && e.args[1] == Symbol("@__dot__")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :(Pair($(margs[1]), $(margs[2])))
subexpr
else
Expr(transform(e.head), transform.(e.args)...)
end
end

function fused_pairs(expr::Expr)
check_restrictions(expr)
e = transform(expr)
@assert e.head == :block
ex = Expr(:call, :tuple, e.args...)
# Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()`
linefilter!(ex)
ex
end

function check_restrictions(expr::Expr)
for _expr in expr.args
_expr isa LineNumberNode && continue
s_error = if _expr isa QuoteNode
"Dangling symbols are not allowed inside fused blocks"
elseif _expr.head == :for
"Loops are not allowed inside fused blocks"
elseif _expr.head == :if
"If-statements are not allowed inside fused blocks"
elseif _expr.head == :call
"Function calls are not allowed inside fused blocks"
elseif _expr.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
elseif _expr.head == :let
"Let-blocks are not allowed inside fused blocks"
elseif _expr.head == :quote
"Quotes are not allowed inside fused blocks"
else
""
end
isempty(s_error) || error(s_error)
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__")
else
@show dump(_expr)
error("Uncaught edge case")
end
end
end
62 changes: 62 additions & 0 deletions src/fused_pairs_flexible.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#####
##### Complex/flexible version
#####

# General case: do nothing (identity)
transform_flex(x, sym) = x
transform_flex(x::Core.SSAValue, sym) = transform_flex(code[x.id], sym)
transform_flex(x::Core.ReturnNode, sym) = transform_flex(code[x.val.id], sym)
transform_flex(s::Symbol, sym) = s
# Expression: recursively transform_flex for Expr
function transform_flex(e::Expr, sym)
if e.head == :macrocall && e.args[1] == Symbol("@__dot__")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
else
Expr(transform_flex(e.head, sym), transform_flex.(e.args, sym)...)
end
end

function fused_pairs_flexible(expr::Expr, sym::Symbol)
check_restrictions_flexible(expr)
e = transform_flex(expr, sym)
@assert e.head == :block
ex = Expr(:block, :($sym = ()), e.args..., sym)
# Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()`
linefilter!(ex)
ex
end

function check_restrictions_flexible(expr::Expr)
for arg in expr.args
arg isa LineNumberNode && continue
s_error = if arg isa QuoteNode
"Dangling symbols are not allowed inside fused blocks"
elseif arg.head == :call
"Function calls are not allowed inside fused blocks"
elseif arg.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
elseif arg.head == :let
"Let-blocks are not allowed inside fused blocks"
elseif arg.head == :quote
"Quotes are not allowed inside fused blocks"
else
""
end
isempty(s_error) || error(s_error)

if arg.head == :macrocall && arg.args[1] == Symbol("@__dot__")
elseif arg.head == :for
check_restrictions(arg.args[2])
elseif arg.head == :if
check_restrictions(arg.args[2])
elseif arg.head == :macrocall && arg.args[1] == Symbol("@inbounds")
else
@show dump(arg)
error("Uncaught edge case")
end
end
return nothing
end
72 changes: 0 additions & 72 deletions src/macro_utils.jl

This file was deleted.

36 changes: 36 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#####
##### Helper
#####

# Recursively remove LineNumberNode from an `Expr`
@noinline function linefilter!(expr::Expr)
total = length(expr.args)
i = 0
while i < total
i += 1
if expr.args[i] |> typeof == Expr
if expr.args[i].head == :line
deleteat!(expr.args, i)
total -= 1
i -= 1
else
expr.args[i] = linefilter!(expr.args[i])
end
elseif expr.args[i] |> typeof == LineNumberNode
if expr.head == :macrocall
expr.args[i] = nothing
else
deleteat!(expr.args, i)
total -= 1
i -= 1
end
end
end
return expr
end

function materialize_args(expr::Expr)
@assert expr.head == :call
@assert expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
end
4 changes: 2 additions & 2 deletions test/expr_errors_and_edge_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
@. y2 = x2 + x3 + x4 + x5
end

expr_out = :((
expr_out = :(tuple(
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
))
Expand All @@ -132,5 +132,5 @@ end

@testset "Empty" begin
expr_in = quote end
@test MBF.fused_pairs(expr_in) == :(())
@test MBF.fused_pairs(expr_in) == :(tuple())
end
Loading

2 comments on commit 3f3b3d9

@charleskawczynski
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/104726

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.1 -m "<description of version>" 3f3b3d90799b1c21605a19bfd84cbdf43f455898
git push origin v0.1.1

Please sign in to comment.