diff --git a/mie/pass/builtin/builtin.c b/mie/pass/builtin/builtin.c index 19fa546..8b98b75 100644 --- a/mie/pass/builtin/builtin.c +++ b/mie/pass/builtin/builtin.c @@ -4,4 +4,5 @@ MIE_PASS_GROUP_BEGIN(mie_builtin) MIE_PASS_GROUP_ADD_PASS(prefix_func_with_underscore); + MIE_PASS_GROUP_ADD_PASS(convert_scf_to_cf); MIE_PASS_GROUP_END() diff --git a/mie/pass/builtin/convert-scf-to-cf.c b/mie/pass/builtin/convert-scf-to-cf.c new file mode 100644 index 0000000..27750e9 --- /dev/null +++ b/mie/pass/builtin/convert-scf-to-cf.c @@ -0,0 +1,136 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static bool op_is_flat(const struct mie_op *op) +{ + return false; +} + +static enum mie_match_result if_match(const struct mie_op *op) +{ + if (!mie_op_is(op, "scf", "if")) { + return MIE_NO_MATCH_FOUND; + } + + return MIE_MATCH_FOUND; +} + +static struct mie_rewrite_result if_rewrite( + struct mie_op *op, struct mie_rewriter *rewriter) +{ + printf("if: rewriting %p %s.%s\n", op, op->op_dialect->d_name, + op->op_info->op_name); + + struct mie_register *cond = mie_op_get_arg(op, 0); + struct mie_region *parent_region = op->op_container->b_parent; + + struct mie_region *if_region = mie_op_get_first_region(op); + struct mie_block *if_start = mie_region_get_first_block(if_region); + struct mie_block *if_end = mie_region_get_last_block(if_region); + struct mie_region *else_region = mie_op_get_next_region(op, if_region); + struct mie_block *else_start = mie_region_get_first_block(else_region); + + struct mie_block *pre_block = op->op_container; + struct mie_block *end_block + = mie_rewriter_split_block(rewriter, pre_block, op, "if.end"); + + for (size_t i = 0; i < MIE_VECTOR_COUNT(op->op_result); i++) { + struct mie_register *old_reg = &op->op_result.items[i]; + struct mie_register *new_reg = mie_block_add_param(end_block); + new_reg->reg_type = old_reg->reg_type; + char *name = b_strdup(old_reg->reg_name.n_str); + + mie_name_destroy(&old_reg->reg_name); + mie_rewriter_rename_register(rewriter, new_reg, name); + mie_rewriter_replace_register(rewriter, old_reg, new_reg); + free(name); + } + + struct mie_walker walker; + mie_walker_begin(&walker, op, MIE_WALKER_F_INCLUDE_OPS); + + do { + const struct mie_walk_item *item = mie_walker_get(&walker); + if (!mie_op_is(item->i_op, "scf", "yield")) { + continue; + } + + printf("if: found scf.yield %p\n", item->i_op); + + struct mie_op *br = mie_rewriter_replace_op( + rewriter, item->i_op, "cf", "br"); + struct mie_op_successor *s = mie_rewriter_add_op_successor( + rewriter, br, end_block, NULL, 0); + mie_rewriter_move_op_args_to_successor(rewriter, br, s); + + } while (mie_walker_step(&walker) == MIE_SUCCESS); + + mie_walker_end(&walker); + + mie_rewriter_move_blocks_after( + rewriter, if_region, parent_region, pre_block); + mie_rewriter_rename_block(rewriter, if_start, "if.then"); + + if (else_region) { + mie_rewriter_move_blocks_after( + rewriter, else_region, parent_region, if_end); + mie_rewriter_rename_block(rewriter, else_start, "if.else"); + } + + mie_rewriter_erase_op(rewriter, op); + + mie_rewriter_set_insertion_block(rewriter, pre_block); + mie_rewriter_set_insertion_point(rewriter, NULL); + mie_cf_br_cond_put( + MIE_EMITTER(rewriter), cond, if_start, NULL, 0, else_start, + NULL, 0); + + return MIE_REWRITE_RESULT(MIE_REWRITE_SUCCESS, MIE_SUCCESS); +} + +MIE_REWRITE_PATTERN_BEGIN(if_pattern) + MIE_REWRITE_PATTERN_ROOT("scf", "if"); + MIE_REWRITE_PATTERN_MATCH(if_match); + MIE_REWRITE_PATTERN_REWRITE(if_rewrite); +MIE_REWRITE_PATTERN_END() + +static struct mie_pass_result transform( + struct mie_pass *pass, struct mie_op *op, struct mie_pass_args *args) +{ + printf("%s: taking a look at %p %s.%s\n", mie_pass_get_name(pass), op, + op->op_dialect->d_name, op->op_info->op_name); + + struct mie_convert_config *cfg = mie_convert_config_create(args->p_ctx); + mie_convert_config_add_illegal_op(cfg, "scf", "if"); + + struct mie_pattern_set patterns = {}; + if_pattern_create(&patterns); + mie_convert_apply(op, cfg, &patterns); + + mie_pattern_set_cleanup(&patterns); + mie_convert_config_destroy(cfg); + + return MIE_PASS_CONTINUE; +} + +MIE_PASS_DEFINITION_BEGIN(convert_scf_to_cf) + MIE_PASS_NAME("convert-scf-to-cf"); + MIE_PASS_DESCRIPTION( + "Convert high-level SCF constructs to low-level CF operations " + "and blocks."); + MIE_PASS_TRANSFORM(transform); + MIE_PASS_FILTER_OP("func", "func"); +MIE_PASS_DEFINITION_END()