diff --git a/mie/dialect/builtin/builtin.c b/mie/dialect/builtin/builtin.c index 5d24f53..f81ea7a 100644 --- a/mie/dialect/builtin/builtin.c +++ b/mie/dialect/builtin/builtin.c @@ -97,6 +97,7 @@ MIE_DIALECT_BEGIN(mie_builtin, struct builtin_dialect, "builtin") MIE_DIALECT_ADD_TYPE(mie_builtin_float); MIE_DIALECT_ADD_TYPE(mie_builtin_index); MIE_DIALECT_ADD_TYPE(mie_builtin_string); + MIE_DIALECT_ADD_TYPE(mie_builtin_memref); MIE_DIALECT_ADD_ATTRIBUTE(mie_builtin_int); MIE_DIALECT_ADD_ATTRIBUTE(mie_builtin_float); MIE_DIALECT_ADD_ATTRIBUTE(mie_builtin_type); diff --git a/mie/dialect/builtin/type/memref.c b/mie/dialect/builtin/type/memref.c new file mode 100644 index 0000000..3ddcbbb --- /dev/null +++ b/mie/dialect/builtin/type/memref.c @@ -0,0 +1,290 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct memref_type { + struct mie_type m_base; + MIE_VECTOR_DECLARE(struct mie_memref_rank, m_rank); +}; + +struct mie_type *mie_ctx_get_memref_type( + struct mie_ctx *ctx, const struct mie_memref_rank *ranks, size_t nr_ranks) +{ + struct mie_type_definition *type_info + = mie_ctx_get_type_definition(ctx, "builtin", "memref"); + if (!type_info) { + return NULL; + } + + struct mie_id_builder id_ctx; + mie_id_builder_begin(&id_ctx, mie_id_map_get_ns(&ctx->ctx_types)); + mie_id_builder_add_cstr(&id_ctx, "memref"); + + for (size_t i = 0; i < nr_ranks; i++) { + switch (ranks[i].r_ranktype) { + case MIE_MEMREF_RANK_UNKNOWN: + mie_id_builder_add_char(&id_ctx, '?'); + break; + case MIE_MEMREF_RANK_TYPE: + mie_id_builder_add_id(&id_ctx, &ranks[i].r_type->ty_id); + break; + case MIE_MEMREF_RANK_STATIC: + mie_id_builder_add_int(&id_ctx, ranks[i].r_static); + break; + default: + return NULL; + } + } + + mie_id id; + mie_id_builder_end(&id_ctx, &id); + + mie_id *target = mie_id_map_get(&ctx->ctx_types, &id); + if (target) { + return b_unbox(struct mie_type, target, ty_id); + } + + struct memref_type *type + = (struct memref_type *)mie_type_create(type_info); + if (!type) { + return NULL; + } + + b_bstr type_name; + b_bstr_begin_dynamic(&type_name); + b_bstr_write_cstr(&type_name, "builtin.memref<", NULL); + + struct mie_printer printer; + mie_printer_init( + &printer, ctx, (b_stream *)&type_name, MIE_PRINT_F_ABBREVIATED); + + for (size_t i = 0; i < nr_ranks; i++) { + mie_vector_push_back(type->m_rank, &ranks[i], NULL); + + if (i > 0) { + b_bstr_write_char(&type_name, '*'); + } + + switch (ranks[i].r_ranktype) { + case MIE_MEMREF_RANK_UNKNOWN: + b_bstr_write_char(&type_name, '?'); + break; + case MIE_MEMREF_RANK_TYPE: + mie_printer_print_type(&printer, ranks[i].r_type); + break; + case MIE_MEMREF_RANK_STATIC: + b_bstr_write_fmt(&type_name, NULL, "%zu", ranks[i].r_static); + break; + default: + break; + } + } + + b_bstr_write_char(&type_name, '>'); + type->m_base.ty_name = b_bstr_end(&type_name); + type->m_base.ty_instance_size = 0; + b_rope name_rope = B_ROPE_CSTR(type->m_base.ty_name); + + mie_id_map_put(&ctx->ctx_types, &type->m_base.ty_id, &name_rope); + return (struct mie_type *)type; +} + +static void type_init( + const struct mie_type_definition *type_info, struct mie_type *type) +{ +} + +static enum mie_status print(const struct mie_type *ty, struct mie_printer *out) +{ + const struct memref_type *memref_ty = (const struct memref_type *)ty; + b_stream_write_string( + out->p_stream, + (out->p_flags & MIE_PRINT_F_ABBREVIATED) ? "memref" + : "builtin.memref", + NULL); + + b_stream_write_char(out->p_stream, '<'); + for (size_t i = 0; i < MIE_VECTOR_COUNT(memref_ty->m_rank); i++) { + const struct mie_memref_rank *rank = &memref_ty->m_rank.items[i]; + + if (i > 0) { + b_stream_write_char(out->p_stream, '*'); + } + + switch (rank->r_ranktype) { + case MIE_MEMREF_RANK_UNKNOWN: + b_stream_write_char(out->p_stream, '?'); + break; + case MIE_MEMREF_RANK_TYPE: + mie_printer_print_type(out, rank->r_type); + break; + case MIE_MEMREF_RANK_STATIC: + b_stream_write_fmt( + out->p_stream, NULL, "%zu", rank->r_static); + break; + default: + break; + } + } + b_stream_write_char(out->p_stream, '>'); + + return MIE_SUCCESS; +} + +static bool parse_rank(struct mie_parser *parser, struct mie_memref_rank *out) +{ + enum mie_token_type tok_type = mie_parser_peek_type(parser); + struct mie_parser_item expected_tokens[] = { + MIE_PARSE_ITEM_TOKEN(MIE_TOK_TYPENAME), + MIE_PARSE_ITEM_TOKEN(MIE_TOK_INT), + MIE_PARSE_ITEM_TOKEN(MIE_SYM_QUESTION), + MIE_PARSE_ITEM_NONE, + }; + + long long v; + + switch (tok_type) { + case MIE_TOK_INT: + out->r_ranktype = MIE_MEMREF_RANK_STATIC; + out->r_type = mie_ctx_get_type( + mie_parser_get_mie_ctx(parser), "builtin", "index"); + mie_parser_parse_int(parser, &v, &out->r_span); + out->r_static = v; + break; + case MIE_TOK_SYMBOL: + switch (mie_parser_peek_symbol(parser)) { + case MIE_SYM_QUESTION: + out->r_span = mie_parser_peek(parser)->tok_location; + mie_parser_parse_symbol(parser, MIE_SYM_QUESTION); + out->r_ranktype = MIE_MEMREF_RANK_UNKNOWN; + break; + case MIE_SYM_LEFT_PAREN: + out->r_ranktype = MIE_MEMREF_RANK_TYPE; + if (!mie_parser_parse_type( + parser, "memref type rank", &out->r_type, + &out->r_span)) { + return false; + } + break; + default: + mie_parser_report_unexpected_token_v( + parser, expected_tokens, "memref rank"); + return false; + } + break; + case MIE_TOK_WORD: + case MIE_TOK_TYPENAME: + out->r_ranktype = MIE_MEMREF_RANK_TYPE; + if (!mie_parser_parse_type( + parser, "memref type rank", &out->r_type, &out->r_span)) { + return false; + } + break; + default: + mie_parser_report_unexpected_token_v( + parser, expected_tokens, "memref rank"); + return false; + } + return true; +} + +static bool parse_ranks( + struct mie_parser *parser, + MIE_VECTOR_REF_PARAM(struct mie_memref_rank, ranks)) +{ + struct mie_memref_rank rank; + if (!parse_rank(parser, &rank)) { + return false; + } + + mie_vector_ref_push_back(ranks, &rank, NULL); + + while (1) { + if (mie_parser_peek_symbol(parser) == MIE_SYM_RIGHT_ANGLE) { + break; + } + + if (!mie_parser_parse_symbol(parser, MIE_SYM_ASTERISK)) { + mie_parser_report_unexpected_token( + parser, MIE_SYM_ASTERISK, "memref rank list"); + return false; + } + + if (!parse_rank(parser, &rank)) { + return false; + } + + mie_vector_ref_push_back(ranks, &rank, NULL); + } + + return true; +} + +static enum mie_status parse(struct mie_parser *parser, const struct mie_type **out) +{ + if (!mie_parser_parse_symbol(parser, MIE_SYM_LEFT_ANGLE)) { + mie_parser_report_unexpected_token( + parser, MIE_SYM_LEFT_ANGLE, "memref type"); + return MIE_ERR_BAD_SYNTAX; + } + + MIE_VECTOR_DEFINE(struct mie_memref_rank, ranks); + if (!parse_ranks(parser, MIE_VECTOR_REF(ranks))) { + return MIE_ERR_BAD_SYNTAX; + } + + if (!mie_parser_parse_symbol(parser, MIE_SYM_RIGHT_ANGLE)) { + mie_parser_report_unexpected_token( + parser, MIE_SYM_RIGHT_ANGLE, "memref type"); + return MIE_ERR_BAD_SYNTAX; + } + + const struct mie_type *type = mie_ctx_get_memref_type( + mie_parser_get_mie_ctx(parser), ranks.items, ranks.count); + mie_vector_destroy(ranks, NULL); + + *out = type; + return MIE_SUCCESS; +} + +size_t mie_memref_type_get_nr_ranks(const struct mie_type *type) +{ + if (!mie_type_is(type, "builtin", "memref")) { + return 0; + } + + const struct memref_type *memref = (const struct memref_type *)type; + return MIE_VECTOR_COUNT(memref->m_rank); +} + +const struct mie_memref_rank *mie_memref_type_get_rank( + const struct mie_type *type, size_t i) +{ + if (!mie_type_is(type, "builtin", "memref")) { + return 0; + } + + const struct memref_type *memref = (const struct memref_type *)type; + if (i >= MIE_VECTOR_COUNT(memref->m_rank)) { + return NULL; + } + + return &memref->m_rank.items[i]; +} + +const struct mie_memref_rank *mie_memref_type_get_rank( + const struct mie_type *type, size_t i); + +MIE_TYPE_DEFINITION_BEGIN(mie_builtin_memref, "memref") + MIE_TYPE_DEFINITION_INIT(type_init); + MIE_TYPE_DEFINITION_STRUCT(struct memref_type); + MIE_TYPE_DEFINITION_PRINT(print); + MIE_TYPE_DEFINITION_PARSE(parse); +MIE_TYPE_DEFINITION_END() diff --git a/mie/dialect/memref/memref.c b/mie/dialect/memref/memref.c index 6d734cd..62beec6 100644 --- a/mie/dialect/memref/memref.c +++ b/mie/dialect/memref/memref.c @@ -6,6 +6,5 @@ #include MIE_DIALECT_BEGIN(mie_memref, struct mie_dialect, "memref") - MIE_DIALECT_ADD_TYPE(mie_memref_memref); MIE_DIALECT_ADD_OP(mie_memref_load); MIE_DIALECT_END() diff --git a/mie/dialect/memref/type/memref.c b/mie/dialect/memref/type/memref.c deleted file mode 100644 index 9386232..0000000 --- a/mie/dialect/memref/type/memref.c +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include -#include -#include -#include -#include - -enum memref_rank_type { - MEMREF_RANK_UNKNOWN = 0, - MEMREF_RANK_STATIC, - MEMREF_RANK_TYPE, -}; - -struct memref_rank { - enum memref_rank_type r_ranktype; - union { - size_t r_static; - const struct mie_type *r_type; - }; -}; - -struct memref_type { - struct mie_type m_base; - MIE_VECTOR_DECLARE(struct memref_rank, m_rank); -}; - -static void type_init( - const struct mie_type_definition *type_info, struct mie_type *type) -{ -} - -static enum mie_status print(const struct mie_type *ty, struct mie_printer *out) -{ - b_stream_write_string( - out->p_stream, - (out->p_flags & MIE_PRINT_F_ABBREVIATED) ? "memref" : "memref.memref", - NULL); - return MIE_SUCCESS; -} - -static enum mie_status parse(struct mie_parser *parser, const struct mie_type **out) -{ - printf("Parse memref!\n"); - - return MIE_ERR_BAD_FORMAT; -} - -MIE_TYPE_DEFINITION_BEGIN(mie_memref_memref, "memref") - MIE_TYPE_DEFINITION_INIT(type_init); - MIE_TYPE_DEFINITION_PRINT(print); - MIE_TYPE_DEFINITION_PARSE(parse); -MIE_TYPE_DEFINITION_END() diff --git a/mie/parse/parser.c b/mie/parse/parser.c index 57aa0fb..cab3ddd 100644 --- a/mie/parse/parser.c +++ b/mie/parse/parser.c @@ -319,7 +319,7 @@ static bool parse_builtin_type_name( if (!strcmp(name_cstr, "memref")) { type_info = mie_ctx_get_type_definition( - ctx->p_ctx, "memref", "memref"); + ctx->p_ctx, "builtin", "memref"); } else if (!strcmp(name_cstr, "index")) { type_info = mie_ctx_get_type_definition( ctx->p_ctx, "builtin", "index");