diff --git a/pl0/pl0.cc b/pl0/pl0.cc index 7fdf92c..3d4105e 100644 --- a/pl0/pl0.cc +++ b/pl0/pl0.cc @@ -79,9 +79,17 @@ struct SymbolScope; struct Annotation { shared_ptr scope; + shared_ptr> freeVariables; }; typedef AstBase AstPL0; +shared_ptr get_closest_scope(shared_ptr ast) { + ast = ast->parent; + while (ast->tag != "block"_) { + ast = ast->parent; + } + return ast->scope; +} /* * Symbol Table @@ -89,29 +97,38 @@ typedef AstBase AstPL0; struct SymbolScope { SymbolScope(shared_ptr outer) : outer(outer) {} - bool has_symbol(const string& ident) const { + bool has_symbol(const string& ident, bool extend = true) const { auto ret = constants.count(ident) || variables.count(ident); - return ret ? true : (outer ? outer->has_symbol(ident) : false); + return ret ? true : (extend && outer ? outer->has_symbol(ident) : false); } - bool has_constant(const string& ident) const { - return constants.count(ident) ? true : (outer ? outer->has_constant(ident) - : false); + bool has_constant(const string& ident, bool extend = true) const { + return constants.count(ident) + ? true + : (extend && outer ? outer->has_constant(ident) : false); } - bool has_variable(const string& ident) const { - return variables.count(ident) ? true : (outer ? outer->has_variable(ident) - : false); + bool has_variable(const string& ident, bool extend = true) const { + return variables.count(ident) + ? true + : (extend && outer ? outer->has_variable(ident) : false); } - bool has_procedure(const string& ident) const { - return procedures.count(ident) ? true : (outer ? outer->has_procedure(ident) - : false); + bool has_procedure(const string& ident, bool extend = true) const { + return procedures.count(ident) + ? true + : (extend && outer ? outer->has_procedure(ident) : false); + } + + shared_ptr get_procedure(const string& ident) const { + auto it = procedures.find(ident); + return it != procedures.end() ? it->second : outer->get_procedure(ident); } map constants; set variables; map> procedures; + set free_variables; private: shared_ptr outer; @@ -161,8 +178,7 @@ struct SymbolTable { static void constants(const shared_ptr ast, shared_ptr scope) { - // const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' - // _)? + // const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' _)? const auto& nodes = ast->nodes; for (auto i = 0u; i < nodes.size(); i += 2) { const auto& ident = nodes[i + 0]->token; @@ -210,6 +226,12 @@ struct SymbolTable { throw_runtime_error(ast->nodes[0], "undefined variable '" + ident + "'..."); } + + build_on_ast(ast->nodes[1], scope); + + if (!scope->has_symbol(ident, false)) { + scope->free_variables.emplace(ident); + } } static void call(const shared_ptr ast, @@ -220,6 +242,15 @@ struct SymbolTable { throw_runtime_error(ast->nodes[0], "undefined procedure '" + ident + "'..."); } + + auto block = scope->get_procedure(ident); + if (block->scope) { + for (const auto& free : block->scope->free_variables) { + if (!scope->has_symbol(free, false)) { + scope->free_variables.emplace(free); + } + } + } } static void ident(const shared_ptr ast, @@ -228,6 +259,10 @@ struct SymbolTable { if (!scope->has_symbol(ident)) { throw_runtime_error(ast, "undefined variable '" + ident + "'..."); } + + if (!scope->has_symbol(ident, false)) { + scope->free_variables.emplace(ident); + } } }; @@ -260,9 +295,7 @@ struct Environment { } shared_ptr get_procedure(const string& ident) const { - auto it = scope->procedures.find(ident); - return it != scope->procedures.end() ? it->second - : outer->get_procedure(ident); + return scope->get_procedure(ident); } private: @@ -602,46 +635,33 @@ struct LLVM { { auto BB = BasicBlock::Create(context_, "entry", fn); builder_.SetInsertPoint(BB); - compile_block(ast->nodes[0], true); + compile_block(ast->nodes[0]); builder_.CreateRetVoid(); } } - void compile_block(const shared_ptr ast, bool top) { - compile_const(ast->nodes[0], top); - compile_var(ast->nodes[1], top); + void compile_block(const shared_ptr ast) { + compile_const(ast->nodes[0]); + compile_var(ast->nodes[1]); compile_procedure(ast->nodes[2]); compile_statement(ast->nodes[3]); } - void compile_const(const shared_ptr ast, bool top) { + void compile_const(const shared_ptr ast) { for (auto i = 0u; i < ast->nodes.size(); i += 2) { auto ident = ast->nodes[i]->token; auto number = stoi(ast->nodes[i + 1]->token); - if (top) { - auto gv = cast( - module_->getOrInsertGlobal(ident, builder_.getInt32Ty())); - gv->setAlignment(4); - gv->setInitializer(builder_.getInt32(number)); - } else { - auto alloca = - builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident); - builder_.CreateStore(builder_.getInt32(number), alloca); - } + auto alloca = + builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident); + builder_.CreateStore(builder_.getInt32(number), alloca); } } - void compile_var(const shared_ptr ast, bool top) { + void compile_var(const shared_ptr ast) { for (const auto node : ast->nodes) { - if (top) { - auto gv = cast( - module_->getOrInsertGlobal(node->token, builder_.getInt32Ty())); - gv->setAlignment(4); - gv->setInitializer(builder_.getInt32(0)); - } else { - builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, node->token); - } + auto ident = node->token; + builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident); } } @@ -650,13 +670,24 @@ struct LLVM { auto ident = ast->nodes[i]->token; auto block = ast->nodes[i + 1]; - auto fn = cast( - module_->getOrInsertFunction(ident, builder_.getVoidTy(), nullptr)); + std::vector pt(block->scope->free_variables.size(), + Type::getInt32PtrTy(context_)); + auto ft = FunctionType::get(builder_.getVoidTy(), pt, false); + auto fn = cast(module_->getOrInsertFunction(ident, ft)); + + { + auto it = block->scope->free_variables.begin(); + for (auto& arg : fn->args()) { + arg.setName(*it); + ++it; + } + } + { auto prevBB = builder_.GetInsertBlock(); auto BB = BasicBlock::Create(context_, "entry", fn); builder_.SetInsertPoint(BB); - compile_block(block, false); + compile_block(block); builder_.CreateRetVoid(); builder_.SetInsertPoint(prevBB); } @@ -670,13 +701,13 @@ struct LLVM { } void compile_assignment(const shared_ptr ast) { - auto name = ast->nodes[0]->token; + auto ident = ast->nodes[0]->token; auto fn = builder_.GetInsertBlock()->getParent(); auto tbl = fn->getValueSymbolTable(); - auto var = tbl->lookup(name); + auto var = tbl->lookup(ident); if (!var) { - var = module_->getGlobalVariable(name); + throw_runtime_error(ast, "'" + ident + "' is not defined..."); } auto val = compile_expression(ast->nodes[1]); @@ -684,8 +715,24 @@ struct LLVM { } void compile_call(const shared_ptr ast) { - auto fn = module_->getFunction(ast->nodes[0]->token); - builder_.CreateCall(fn); + auto ident = ast->nodes[0]->token; + + auto scope = get_closest_scope(ast); + auto block = scope->get_procedure(ident); + + std::vector args; + for (auto& free : block->scope->free_variables) { + auto fn = builder_.GetInsertBlock()->getParent(); + auto tbl = fn->getValueSymbolTable(); + auto var = tbl->lookup(free); + if (!var) { + throw_runtime_error(ast, "'" + free + "' is not defined..."); + } + args.push_back(var); + } + + auto fn = module_->getFunction(ident); + builder_.CreateCall(fn, args); } void compile_statements(const shared_ptr ast) { @@ -831,13 +878,13 @@ struct LLVM { } Value* compile_ident(const shared_ptr ast) { - auto name = ast->token; + auto ident = ast->token; auto fn = builder_.GetInsertBlock()->getParent(); auto tbl = fn->getValueSymbolTable(); - auto var = tbl->lookup(name); + auto var = tbl->lookup(ident); if (!var) { - var = module_->getGlobalVariable(name); + throw_runtime_error(ast, "'" + ident + "' is not defined..."); } return builder_.CreateLoad(var); @@ -900,13 +947,13 @@ int main(int argc, const char** argv) { } } - if (opt_ast) { - cout << ast_to_s(ast); - } - try { SymbolTable::build_on_ast(ast); + if (opt_ast) { + cout << ast_to_s(ast); + } + if (opt_llvm || opt_jit) { if (opt_llvm) { LLVM::dump(ast);