pass: convert-scf-to-cf: implement for-loop rewriting

This commit is contained in:
2026-03-16 12:05:14 +00:00
parent 4d315d6abe
commit 8c9cab7408

View File

@@ -1,4 +1,5 @@
#include <mie/ctx.h>
#include <mie/dialect/arith.h>
#include <mie/dialect/builtin.h>
#include <mie/dialect/cf.h>
#include <mie/dialect/dialect.h>
@@ -101,12 +102,144 @@ static struct mie_rewrite_result if_rewrite(
return MIE_REWRITE_RESULT(MIE_REWRITE_SUCCESS, MIE_SUCCESS);
}
static enum mie_match_result for_match(const struct mie_op *op)
{
if (!mie_op_is(op, "scf", "for")) {
return MIE_NO_MATCH_FOUND;
}
return MIE_MATCH_FOUND;
}
static struct mie_rewrite_result for_rewrite(
struct mie_op *op, struct mie_rewriter *rewriter)
{
printf("for: rewriting %p %s.%s\n", op, op->op_info->op_parent->d_name,
op->op_info->op_name);
struct mie_region *parent_region = op->op_container->b_parent;
struct mie_region *for_body = mie_op_get_first_region(op);
struct mie_block *for_entry = mie_region_get_first_block(for_body);
struct mie_block *pre_block = op->op_container;
struct mie_block *end_block
= mie_rewriter_split_block(rewriter, pre_block, op, "for.end");
struct mie_register *entry_iv = &for_entry->b_params.items[0];
struct mie_register *lb = op->op_args.items[0].arg_value.u_reg;
struct mie_register *ub = op->op_args.items[1].arg_value.u_reg;
struct mie_register *step = op->op_args.items[2].arg_value.u_reg;
MIE_VECTOR_DEFINE(struct mie_register *, initial_args);
mie_vector_push_back(initial_args, &lb, NULL);
for (size_t i = 3; i < MIE_VECTOR_COUNT(op->op_args); i++) {
struct mie_register *arg = op->op_args.items[i].arg_value.u_reg;
mie_vector_push_back(initial_args, &arg, NULL);
}
mie_rewriter_set_insertion_block(rewriter, pre_block);
mie_cf_br_put(
MIE_EMITTER(rewriter), for_entry, initial_args.items,
initial_args.count);
char iv_next_name[64];
snprintf(
iv_next_name, sizeof iv_next_name, "%s.next",
entry_iv->reg_name.n_str);
mie_rewriter_set_insertion_block(rewriter, for_entry);
struct mie_block *for_cond
= mie_rewriter_create_block(rewriter, end_block, "for.cond");
for (size_t i = 0; i < MIE_VECTOR_COUNT(for_entry->b_params); i++) {
const char *var_name = for_entry->b_params.items[i].reg_name.n_str;
const struct mie_type *var_type
= for_entry->b_params.items[i].reg_type;
mie_rewriter_add_block_parameter(
rewriter, for_cond, var_name, entry_iv->reg_type);
}
for (size_t i = 0; i < MIE_VECTOR_COUNT(op->op_result); i++) {
struct mie_register *old_reg = &op->op_result.items[i];
struct mie_register *new_reg = mie_block_add_param(end_block);
new_reg->reg_type = old_reg->reg_type;
char *name = b_strdup(old_reg->reg_name.n_str);
mie_name_destroy(&old_reg->reg_name);
mie_rewriter_rename_register(rewriter, new_reg, name);
mie_rewriter_replace_register(rewriter, old_reg, new_reg);
free(name);
}
struct mie_walker walker;
mie_walker_begin(&walker, op, MIE_WALKER_F_INCLUDE_OPS);
do {
const struct mie_walk_item *item = mie_walker_get(&walker);
if (!mie_op_is(item->i_op, "scf", "yield")) {
continue;
}
printf("for: found scf.yield %p\n", item->i_op);
struct mie_op *br = mie_rewriter_replace_op(
rewriter, item->i_op, "cf", "br");
struct mie_op_successor *s = mie_rewriter_add_op_successor(
rewriter, br, for_cond, NULL, 0);
struct mie_op_arg *iv_arg = mie_rewriter_add_op_successor_arg(
rewriter, br, s, entry_iv);
mie_rewriter_move_op_args_to_successor(rewriter, br, s);
} while (mie_walker_step(&walker) == MIE_SUCCESS);
mie_walker_end(&walker);
mie_rewriter_move_blocks_after(rewriter, for_body, parent_region, pre_block);
mie_rewriter_set_insertion_block(rewriter, for_cond);
struct mie_register *cond_iv = &for_cond->b_params.items[0];
struct mie_register *iv_next = mie_arith_addi_put(
MIE_EMITTER(rewriter), cond_iv, step, iv_next_name);
struct mie_register *iv_cmp = mie_arith_cmpi_put(
MIE_EMITTER(rewriter), MIE_ARITH_CMPI_UGE, iv_next, ub, "stop");
MIE_VECTOR_DEFINE(struct mie_register *, true_args);
MIE_VECTOR_DEFINE(struct mie_register *, false_args);
mie_vector_push_back(false_args, &iv_next, NULL);
for (size_t i = 1; i < MIE_VECTOR_COUNT(for_cond->b_params); i++) {
struct mie_register *param = &for_cond->b_params.items[i];
mie_vector_push_back(true_args, &param, NULL);
mie_vector_push_back(false_args, &param, NULL);
}
mie_cf_br_cond_put(
MIE_EMITTER(rewriter), iv_cmp, end_block, true_args.items,
true_args.count, for_entry, false_args.items, false_args.count);
mie_vector_destroy(true_args, NULL);
mie_vector_destroy(false_args, NULL);
mie_rewriter_erase_op(rewriter, op);
return MIE_REWRITE_RESULT(MIE_REWRITE_SUCCESS, MIE_SUCCESS);
}
MIE_REWRITE_PATTERN_BEGIN(if_pattern)
MIE_REWRITE_PATTERN_ROOT("scf", "if");
MIE_REWRITE_PATTERN_MATCH(if_match);
MIE_REWRITE_PATTERN_REWRITE(if_rewrite);
MIE_REWRITE_PATTERN_END()
MIE_REWRITE_PATTERN_BEGIN(for_pattern)
MIE_REWRITE_PATTERN_ROOT("scf", "for");
MIE_REWRITE_PATTERN_MATCH(for_match);
MIE_REWRITE_PATTERN_REWRITE(for_rewrite);
MIE_REWRITE_PATTERN_END()
static struct mie_pass_result transform(
struct mie_pass *pass, struct mie_op *op, struct mie_pass_args *args)
{
@@ -115,9 +248,11 @@ static struct mie_pass_result transform(
struct mie_convert_config *cfg = mie_convert_config_create(args->p_ctx);
mie_convert_config_add_illegal_op(cfg, "scf", "if");
mie_convert_config_add_illegal_op(cfg, "scf", "for");
struct mie_pattern_set patterns = {};
if_pattern_create(&patterns);
for_pattern_create(&patterns);
mie_convert_apply(op, cfg, &patterns);
mie_pattern_set_cleanup(&patterns);