Fixed nested procedure problem in PL/0.

This commit is contained in:
yhirose 2017-08-17 01:37:31 -04:00
parent 925a611ec0
commit 463ed17967

View File

@ -79,9 +79,17 @@ struct SymbolScope;
struct Annotation {
shared_ptr<SymbolScope> scope;
shared_ptr<vector<string>> freeVariables;
};
typedef AstBase<Annotation> AstPL0;
shared_ptr<SymbolScope> get_closest_scope(shared_ptr<AstPL0> ast) {
ast = ast->parent;
while (ast->tag != "block"_) {
ast = ast->parent;
}
return ast->scope;
}
/*
* Symbol Table
@ -89,29 +97,38 @@ typedef AstBase<Annotation> AstPL0;
struct SymbolScope {
SymbolScope(shared_ptr<SymbolScope> outer) : outer(outer) {}
bool has_symbol(const string& ident) const {
bool has_symbol(const string& ident, bool extend = true) const {
auto ret = constants.count(ident) || variables.count(ident);
return ret ? true : (outer ? outer->has_symbol(ident) : false);
return ret ? true : (extend && outer ? outer->has_symbol(ident) : false);
}
bool has_constant(const string& ident) const {
return constants.count(ident) ? true : (outer ? outer->has_constant(ident)
: false);
bool has_constant(const string& ident, bool extend = true) const {
return constants.count(ident)
? true
: (extend && outer ? outer->has_constant(ident) : false);
}
bool has_variable(const string& ident) const {
return variables.count(ident) ? true : (outer ? outer->has_variable(ident)
: false);
bool has_variable(const string& ident, bool extend = true) const {
return variables.count(ident)
? true
: (extend && outer ? outer->has_variable(ident) : false);
}
bool has_procedure(const string& ident) const {
return procedures.count(ident) ? true : (outer ? outer->has_procedure(ident)
: false);
bool has_procedure(const string& ident, bool extend = true) const {
return procedures.count(ident)
? true
: (extend && outer ? outer->has_procedure(ident) : false);
}
shared_ptr<AstPL0> get_procedure(const string& ident) const {
auto it = procedures.find(ident);
return it != procedures.end() ? it->second : outer->get_procedure(ident);
}
map<string, int> constants;
set<string> variables;
map<string, shared_ptr<AstPL0>> procedures;
set<string> free_variables;
private:
shared_ptr<SymbolScope> outer;
@ -161,8 +178,7 @@ struct SymbolTable {
static void constants(const shared_ptr<AstPL0> ast,
shared_ptr<SymbolScope> scope) {
// const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';'
// _)?
// const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' _)?
const auto& nodes = ast->nodes;
for (auto i = 0u; i < nodes.size(); i += 2) {
const auto& ident = nodes[i + 0]->token;
@ -210,6 +226,12 @@ struct SymbolTable {
throw_runtime_error(ast->nodes[0],
"undefined variable '" + ident + "'...");
}
build_on_ast(ast->nodes[1], scope);
if (!scope->has_symbol(ident, false)) {
scope->free_variables.emplace(ident);
}
}
static void call(const shared_ptr<AstPL0> ast,
@ -220,6 +242,15 @@ struct SymbolTable {
throw_runtime_error(ast->nodes[0],
"undefined procedure '" + ident + "'...");
}
auto block = scope->get_procedure(ident);
if (block->scope) {
for (const auto& free : block->scope->free_variables) {
if (!scope->has_symbol(free, false)) {
scope->free_variables.emplace(free);
}
}
}
}
static void ident(const shared_ptr<AstPL0> ast,
@ -228,6 +259,10 @@ struct SymbolTable {
if (!scope->has_symbol(ident)) {
throw_runtime_error(ast, "undefined variable '" + ident + "'...");
}
if (!scope->has_symbol(ident, false)) {
scope->free_variables.emplace(ident);
}
}
};
@ -260,9 +295,7 @@ struct Environment {
}
shared_ptr<AstPL0> get_procedure(const string& ident) const {
auto it = scope->procedures.find(ident);
return it != scope->procedures.end() ? it->second
: outer->get_procedure(ident);
return scope->get_procedure(ident);
}
private:
@ -602,46 +635,33 @@ struct LLVM {
{
auto BB = BasicBlock::Create(context_, "entry", fn);
builder_.SetInsertPoint(BB);
compile_block(ast->nodes[0], true);
compile_block(ast->nodes[0]);
builder_.CreateRetVoid();
}
}
void compile_block(const shared_ptr<AstPL0> ast, bool top) {
compile_const(ast->nodes[0], top);
compile_var(ast->nodes[1], top);
void compile_block(const shared_ptr<AstPL0> ast) {
compile_const(ast->nodes[0]);
compile_var(ast->nodes[1]);
compile_procedure(ast->nodes[2]);
compile_statement(ast->nodes[3]);
}
void compile_const(const shared_ptr<AstPL0> ast, bool top) {
void compile_const(const shared_ptr<AstPL0> ast) {
for (auto i = 0u; i < ast->nodes.size(); i += 2) {
auto ident = ast->nodes[i]->token;
auto number = stoi(ast->nodes[i + 1]->token);
if (top) {
auto gv = cast<GlobalVariable>(
module_->getOrInsertGlobal(ident, builder_.getInt32Ty()));
gv->setAlignment(4);
gv->setInitializer(builder_.getInt32(number));
} else {
auto alloca =
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
builder_.CreateStore(builder_.getInt32(number), alloca);
}
auto alloca =
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
builder_.CreateStore(builder_.getInt32(number), alloca);
}
}
void compile_var(const shared_ptr<AstPL0> ast, bool top) {
void compile_var(const shared_ptr<AstPL0> ast) {
for (const auto node : ast->nodes) {
if (top) {
auto gv = cast<GlobalVariable>(
module_->getOrInsertGlobal(node->token, builder_.getInt32Ty()));
gv->setAlignment(4);
gv->setInitializer(builder_.getInt32(0));
} else {
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, node->token);
}
auto ident = node->token;
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
}
}
@ -650,13 +670,24 @@ struct LLVM {
auto ident = ast->nodes[i]->token;
auto block = ast->nodes[i + 1];
auto fn = cast<Function>(
module_->getOrInsertFunction(ident, builder_.getVoidTy(), nullptr));
std::vector<Type*> pt(block->scope->free_variables.size(),
Type::getInt32PtrTy(context_));
auto ft = FunctionType::get(builder_.getVoidTy(), pt, false);
auto fn = cast<Function>(module_->getOrInsertFunction(ident, ft));
{
auto it = block->scope->free_variables.begin();
for (auto& arg : fn->args()) {
arg.setName(*it);
++it;
}
}
{
auto prevBB = builder_.GetInsertBlock();
auto BB = BasicBlock::Create(context_, "entry", fn);
builder_.SetInsertPoint(BB);
compile_block(block, false);
compile_block(block);
builder_.CreateRetVoid();
builder_.SetInsertPoint(prevBB);
}
@ -670,13 +701,13 @@ struct LLVM {
}
void compile_assignment(const shared_ptr<AstPL0> ast) {
auto name = ast->nodes[0]->token;
auto ident = ast->nodes[0]->token;
auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(name);
auto var = tbl->lookup(ident);
if (!var) {
var = module_->getGlobalVariable(name);
throw_runtime_error(ast, "'" + ident + "' is not defined...");
}
auto val = compile_expression(ast->nodes[1]);
@ -684,8 +715,24 @@ struct LLVM {
}
void compile_call(const shared_ptr<AstPL0> ast) {
auto fn = module_->getFunction(ast->nodes[0]->token);
builder_.CreateCall(fn);
auto ident = ast->nodes[0]->token;
auto scope = get_closest_scope(ast);
auto block = scope->get_procedure(ident);
std::vector<Value*> args;
for (auto& free : block->scope->free_variables) {
auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(free);
if (!var) {
throw_runtime_error(ast, "'" + free + "' is not defined...");
}
args.push_back(var);
}
auto fn = module_->getFunction(ident);
builder_.CreateCall(fn, args);
}
void compile_statements(const shared_ptr<AstPL0> ast) {
@ -831,13 +878,13 @@ struct LLVM {
}
Value* compile_ident(const shared_ptr<AstPL0> ast) {
auto name = ast->token;
auto ident = ast->token;
auto fn = builder_.GetInsertBlock()->getParent();
auto tbl = fn->getValueSymbolTable();
auto var = tbl->lookup(name);
auto var = tbl->lookup(ident);
if (!var) {
var = module_->getGlobalVariable(name);
throw_runtime_error(ast, "'" + ident + "' is not defined...");
}
return builder_.CreateLoad(var);
@ -900,13 +947,13 @@ int main(int argc, const char** argv) {
}
}
if (opt_ast) {
cout << ast_to_s(ast);
}
try {
SymbolTable::build_on_ast(ast);
if (opt_ast) {
cout << ast_to_s<AstPL0>(ast);
}
if (opt_llvm || opt_jit) {
if (opt_llvm) {
LLVM::dump(ast);