Fixed nested procedure problem in PL/0.

This commit is contained in:
yhirose 2017-08-17 01:37:31 -04:00
parent 925a611ec0
commit 463ed17967

View File

@ -79,9 +79,17 @@ struct SymbolScope;
struct Annotation { struct Annotation {
shared_ptr<SymbolScope> scope; shared_ptr<SymbolScope> scope;
shared_ptr<vector<string>> freeVariables;
}; };
typedef AstBase<Annotation> AstPL0; typedef AstBase<Annotation> AstPL0;
shared_ptr<SymbolScope> get_closest_scope(shared_ptr<AstPL0> ast) {
ast = ast->parent;
while (ast->tag != "block"_) {
ast = ast->parent;
}
return ast->scope;
}
/* /*
* Symbol Table * Symbol Table
@ -89,29 +97,38 @@ typedef AstBase<Annotation> AstPL0;
struct SymbolScope { struct SymbolScope {
SymbolScope(shared_ptr<SymbolScope> outer) : outer(outer) {} SymbolScope(shared_ptr<SymbolScope> outer) : outer(outer) {}
bool has_symbol(const string& ident) const { bool has_symbol(const string& ident, bool extend = true) const {
auto ret = constants.count(ident) || variables.count(ident); auto ret = constants.count(ident) || variables.count(ident);
return ret ? true : (outer ? outer->has_symbol(ident) : false); return ret ? true : (extend && outer ? outer->has_symbol(ident) : false);
} }
bool has_constant(const string& ident) const { bool has_constant(const string& ident, bool extend = true) const {
return constants.count(ident) ? true : (outer ? outer->has_constant(ident) return constants.count(ident)
: false); ? true
: (extend && outer ? outer->has_constant(ident) : false);
} }
bool has_variable(const string& ident) const { bool has_variable(const string& ident, bool extend = true) const {
return variables.count(ident) ? true : (outer ? outer->has_variable(ident) return variables.count(ident)
: false); ? true
: (extend && outer ? outer->has_variable(ident) : false);
} }
bool has_procedure(const string& ident) const { bool has_procedure(const string& ident, bool extend = true) const {
return procedures.count(ident) ? true : (outer ? outer->has_procedure(ident) return procedures.count(ident)
: false); ? true
: (extend && outer ? outer->has_procedure(ident) : false);
}
shared_ptr<AstPL0> get_procedure(const string& ident) const {
auto it = procedures.find(ident);
return it != procedures.end() ? it->second : outer->get_procedure(ident);
} }
map<string, int> constants; map<string, int> constants;
set<string> variables; set<string> variables;
map<string, shared_ptr<AstPL0>> procedures; map<string, shared_ptr<AstPL0>> procedures;
set<string> free_variables;
private: private:
shared_ptr<SymbolScope> outer; shared_ptr<SymbolScope> outer;
@ -161,8 +178,7 @@ struct SymbolTable {
static void constants(const shared_ptr<AstPL0> ast, static void constants(const shared_ptr<AstPL0> ast,
shared_ptr<SymbolScope> scope) { shared_ptr<SymbolScope> scope) {
// const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' // const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' _)?
// _)?
const auto& nodes = ast->nodes; const auto& nodes = ast->nodes;
for (auto i = 0u; i < nodes.size(); i += 2) { for (auto i = 0u; i < nodes.size(); i += 2) {
const auto& ident = nodes[i + 0]->token; const auto& ident = nodes[i + 0]->token;
@ -210,6 +226,12 @@ struct SymbolTable {
throw_runtime_error(ast->nodes[0], throw_runtime_error(ast->nodes[0],
"undefined variable '" + ident + "'..."); "undefined variable '" + ident + "'...");
} }
build_on_ast(ast->nodes[1], scope);
if (!scope->has_symbol(ident, false)) {
scope->free_variables.emplace(ident);
}
} }
static void call(const shared_ptr<AstPL0> ast, static void call(const shared_ptr<AstPL0> ast,
@ -220,6 +242,15 @@ struct SymbolTable {
throw_runtime_error(ast->nodes[0], throw_runtime_error(ast->nodes[0],
"undefined procedure '" + ident + "'..."); "undefined procedure '" + ident + "'...");
} }
auto block = scope->get_procedure(ident);
if (block->scope) {
for (const auto& free : block->scope->free_variables) {
if (!scope->has_symbol(free, false)) {
scope->free_variables.emplace(free);
}
}
}
} }
static void ident(const shared_ptr<AstPL0> ast, static void ident(const shared_ptr<AstPL0> ast,
@ -228,6 +259,10 @@ struct SymbolTable {
if (!scope->has_symbol(ident)) { if (!scope->has_symbol(ident)) {
throw_runtime_error(ast, "undefined variable '" + ident + "'..."); throw_runtime_error(ast, "undefined variable '" + ident + "'...");
} }
if (!scope->has_symbol(ident, false)) {
scope->free_variables.emplace(ident);
}
} }
}; };
@ -260,9 +295,7 @@ struct Environment {
} }
shared_ptr<AstPL0> get_procedure(const string& ident) const { shared_ptr<AstPL0> get_procedure(const string& ident) const {
auto it = scope->procedures.find(ident); return scope->get_procedure(ident);
return it != scope->procedures.end() ? it->second
: outer->get_procedure(ident);
} }
private: private:
@ -602,46 +635,33 @@ struct LLVM {
{ {
auto BB = BasicBlock::Create(context_, "entry", fn); auto BB = BasicBlock::Create(context_, "entry", fn);
builder_.SetInsertPoint(BB); builder_.SetInsertPoint(BB);
compile_block(ast->nodes[0], true); compile_block(ast->nodes[0]);
builder_.CreateRetVoid(); builder_.CreateRetVoid();
} }
} }
void compile_block(const shared_ptr<AstPL0> ast, bool top) { void compile_block(const shared_ptr<AstPL0> ast) {
compile_const(ast->nodes[0], top); compile_const(ast->nodes[0]);
compile_var(ast->nodes[1], top); compile_var(ast->nodes[1]);
compile_procedure(ast->nodes[2]); compile_procedure(ast->nodes[2]);
compile_statement(ast->nodes[3]); compile_statement(ast->nodes[3]);
} }
void compile_const(const shared_ptr<AstPL0> ast, bool top) { void compile_const(const shared_ptr<AstPL0> ast) {
for (auto i = 0u; i < ast->nodes.size(); i += 2) { for (auto i = 0u; i < ast->nodes.size(); i += 2) {
auto ident = ast->nodes[i]->token; auto ident = ast->nodes[i]->token;
auto number = stoi(ast->nodes[i + 1]->token); auto number = stoi(ast->nodes[i + 1]->token);
if (top) { auto alloca =
auto gv = cast<GlobalVariable>( builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
module_->getOrInsertGlobal(ident, builder_.getInt32Ty())); builder_.CreateStore(builder_.getInt32(number), alloca);
gv->setAlignment(4);
gv->setInitializer(builder_.getInt32(number));
} else {
auto alloca =
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
builder_.CreateStore(builder_.getInt32(number), alloca);
}
} }
} }
void compile_var(const shared_ptr<AstPL0> ast, bool top) { void compile_var(const shared_ptr<AstPL0> ast) {
for (const auto node : ast->nodes) { for (const auto node : ast->nodes) {
if (top) { auto ident = node->token;
auto gv = cast<GlobalVariable>( builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
module_->getOrInsertGlobal(node->token, builder_.getInt32Ty()));
gv->setAlignment(4);
gv->setInitializer(builder_.getInt32(0));
} else {
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, node->token);
}
} }
} }
@ -650,13 +670,24 @@ struct LLVM {
auto ident = ast->nodes[i]->token; auto ident = ast->nodes[i]->token;
auto block = ast->nodes[i + 1]; auto block = ast->nodes[i + 1];
auto fn = cast<Function>( std::vector<Type*> pt(block->scope->free_variables.size(),
module_->getOrInsertFunction(ident, builder_.getVoidTy(), nullptr)); Type::getInt32PtrTy(context_));
auto ft = FunctionType::get(builder_.getVoidTy(), pt, false);
auto fn = cast<Function>(module_->getOrInsertFunction(ident, ft));
{
auto it = block->scope->free_variables.begin();
for (auto& arg : fn->args()) {
arg.setName(*it);
++it;
}
}
{ {
auto prevBB = builder_.GetInsertBlock(); auto prevBB = builder_.GetInsertBlock();
auto BB = BasicBlock::Create(context_, "entry", fn); auto BB = BasicBlock::Create(context_, "entry", fn);
builder_.SetInsertPoint(BB); builder_.SetInsertPoint(BB);
compile_block(block, false); compile_block(block);
builder_.CreateRetVoid(); builder_.CreateRetVoid();
builder_.SetInsertPoint(prevBB); builder_.SetInsertPoint(prevBB);
} }
@ -670,13 +701,13 @@ struct LLVM {
} }
void compile_assignment(const shared_ptr<AstPL0> ast) { void compile_assignment(const shared_ptr<AstPL0> ast) {
auto name = ast->nodes[0]->token; auto ident = ast->nodes[0]->token;
auto fn = builder_.GetInsertBlock()->getParent(); auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable(); auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(name); auto var = tbl->lookup(ident);
if (!var) { if (!var) {
var = module_->getGlobalVariable(name); throw_runtime_error(ast, "'" + ident + "' is not defined...");
} }
auto val = compile_expression(ast->nodes[1]); auto val = compile_expression(ast->nodes[1]);
@ -684,8 +715,24 @@ struct LLVM {
} }
void compile_call(const shared_ptr<AstPL0> ast) { void compile_call(const shared_ptr<AstPL0> ast) {
auto fn = module_->getFunction(ast->nodes[0]->token); auto ident = ast->nodes[0]->token;
builder_.CreateCall(fn);
auto scope = get_closest_scope(ast);
auto block = scope->get_procedure(ident);
std::vector<Value*> args;
for (auto& free : block->scope->free_variables) {
auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(free);
if (!var) {
throw_runtime_error(ast, "'" + free + "' is not defined...");
}
args.push_back(var);
}
auto fn = module_->getFunction(ident);
builder_.CreateCall(fn, args);
} }
void compile_statements(const shared_ptr<AstPL0> ast) { void compile_statements(const shared_ptr<AstPL0> ast) {
@ -831,13 +878,13 @@ struct LLVM {
} }
Value* compile_ident(const shared_ptr<AstPL0> ast) { Value* compile_ident(const shared_ptr<AstPL0> ast) {
auto name = ast->token; auto ident = ast->token;
auto fn = builder_.GetInsertBlock()->getParent(); auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable(); auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(name); auto var = tbl->lookup(ident);
if (!var) { if (!var) {
var = module_->getGlobalVariable(name); throw_runtime_error(ast, "'" + ident + "' is not defined...");
} }
return builder_.CreateLoad(var); return builder_.CreateLoad(var);
@ -900,13 +947,13 @@ int main(int argc, const char** argv) {
} }
} }
if (opt_ast) {
cout << ast_to_s(ast);
}
try { try {
SymbolTable::build_on_ast(ast); SymbolTable::build_on_ast(ast);
if (opt_ast) {
cout << ast_to_s<AstPL0>(ast);
}
if (opt_llvm || opt_jit) { if (opt_llvm || opt_jit) {
if (opt_llvm) { if (opt_llvm) {
LLVM::dump(ast); LLVM::dump(ast);