diff options
Diffstat (limited to 'submit/ASTVisitor.java')
-rw-r--r-- | submit/ASTVisitor.java | 624 |
1 files changed, 331 insertions, 293 deletions
diff --git a/submit/ASTVisitor.java b/submit/ASTVisitor.java index b8762b5..5cafbb1 100644 --- a/submit/ASTVisitor.java +++ b/submit/ASTVisitor.java @@ -1,304 +1,342 @@ package submit; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.TerminalNode; import parser.CminusBaseVisitor; import parser.CminusParser; import submit.ast.*; -import java.util.ArrayList; -import java.util.List; -import java.util.logging.Logger; - public class ASTVisitor extends CminusBaseVisitor<Node> { - private final Logger LOGGER; - private SymbolTable symbolTable; - - public ASTVisitor(Logger LOGGER) { - this.LOGGER = LOGGER; - } - - public SymbolTable getSymbolTable() { - return symbolTable; - } - - private VarType getVarType(CminusParser.TypeSpecifierContext ctx) { - final String t = ctx.getText(); - return (t.equals("int")) ? VarType.INT : (t.equals("bool")) ? VarType.BOOL : VarType.CHAR; - } - - @Override - public Node visitProgram(CminusParser.ProgramContext ctx) { - symbolTable = new SymbolTable(); - List<Declaration> decls = new ArrayList<>(); - for (CminusParser.DeclarationContext d : ctx.declaration()) { - decls.add((Declaration) visitDeclaration(d)); - } - return new Program(decls); - } - - @Override - public Node visitVarDeclaration(CminusParser.VarDeclarationContext ctx) { - VarType type = getVarType(ctx.typeSpecifier()); - List<String> ids = new ArrayList<>(); - List<Integer> arraySizes = new ArrayList<>(); - for (CminusParser.VarDeclIdContext v : ctx.varDeclId()) { - String id = v.ID().getText(); - ids.add(id); - symbolTable.addSymbol(id, new SymbolInfo(id, type, false)); - if (v.NUMCONST() != null) { - arraySizes.add(Integer.parseInt(v.NUMCONST().getText())); - } else { - arraySizes.add(-1); - } - } - final boolean isStatic = false; - return new VarDeclaration(type, ids, arraySizes, isStatic); - } - - @Override - public Node visitFunDeclaration(CminusParser.FunDeclarationContext ctx) { - VarType returnType = null; - if (ctx.typeSpecifier() != null) { - returnType = getVarType(ctx.typeSpecifier()); - } - String id = ctx.ID().getText(); - List<Param> params = new ArrayList<>(); - for (CminusParser.ParamContext p : ctx.param()) { - params.add((Param) visitParam(p)); - } - Statement statement = (Statement) visitStatement(ctx.statement()); - symbolTable.addSymbol(id, new SymbolInfo(id, returnType, true)); - return new FunDeclaration(returnType, id, params, statement); - } - - @Override - public Node visitParam(CminusParser.ParamContext ctx) { - VarType type = getVarType(ctx.typeSpecifier()); - String id = ctx.paramId().ID().getText(); - symbolTable.addSymbol(id, new SymbolInfo(id, type, false)); - return new Param(type, id, ctx.paramId().children.size() > 1); - } - - @Override - public Node visitCompoundStmt(CminusParser.CompoundStmtContext ctx) { - symbolTable = symbolTable.createChild(); - List<Statement> statements = new ArrayList<>(); - for (CminusParser.VarDeclarationContext d : ctx.varDeclaration()) { - statements.add((VarDeclaration) visitVarDeclaration(d)); - } - for (CminusParser.StatementContext d : ctx.statement()) { - statements.add((Statement) visitStatement(d)); - } - symbolTable = symbolTable.getParent(); - return new CompoundStatement(statements); - } - - @Override - public Node visitExpressionStmt(CminusParser.ExpressionStmtContext ctx) { - if (ctx.expression() == null) { - return Statement.empty(); - } - return new ExpressionStatement((Expression) visitExpression(ctx.expression())); - } - - @Override - public Node visitIfStmt(CminusParser.IfStmtContext ctx) { - Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression()); - Statement trueStatement = (Statement) visitStatement(ctx.statement(0)); - Statement falseStatement = null; - if (ctx.statement().size() > 1) { - falseStatement = (Statement) visitStatement(ctx.statement(1)); - } - return new If(expression, trueStatement, falseStatement); - } - - @Override - public Node visitWhileStmt(CminusParser.WhileStmtContext ctx) { - Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression()); - Statement statement = (Statement) visitStatement(ctx.statement()); - return new While(expression, statement); - } - - @Override - public Node visitReturnStmt(CminusParser.ReturnStmtContext ctx) { - if (ctx.expression() != null) { - return new Return((Expression) visitExpression(ctx.expression())); - } - return new Return(null); - } - - @Override - public Node visitBreakStmt(CminusParser.BreakStmtContext ctx) { - return new Break(); - } - - @Override - public Node visitExpression(CminusParser.ExpressionContext ctx) { - final Node ret; - CminusParser.MutableContext mutable = ctx.mutable(); - CminusParser.ExpressionContext expression = ctx.expression(); - if (mutable != null) { - // Assignment - ParseTree operator = ctx.getChild(1); - Mutable lhs = (Mutable) visitMutable(mutable);// new Mutable(mutable.ID().getText(), (Expression) - // visitExpression(mutable.expression())); - Expression rhs = null; - if (expression != null) { - rhs = (Expression) visitExpression(expression); - } - ret = new Assignment(lhs, operator.getText(), rhs); - } else { - ret = visitSimpleExpression(ctx.simpleExpression()); - } - return ret; - } - - @Override - public Node visitOrExpression(CminusParser.OrExpressionContext ctx) { - List<Node> ands = new ArrayList<>(); - for (CminusParser.AndExpressionContext and : ctx.andExpression()) { - ands.add(visitAndExpression(and)); - } - if (ands.size() == 1) { - return ands.get(0); - } - BinaryOperator op = new BinaryOperator((Expression) ands.get(0), "||", (Expression) ands.get(1)); - for (int i = 2; i < ands.size(); ++i) { - op = new BinaryOperator(op, "||", (Expression) ands.get(i)); - } - return op; - } - - @Override - public Node visitAndExpression(CminusParser.AndExpressionContext ctx) { - List<Node> uns = new ArrayList<>(); - for (CminusParser.UnaryRelExpressionContext un : ctx.unaryRelExpression()) { - uns.add(visitUnaryRelExpression(un)); - } - if (uns.size() == 1) { - return uns.get(0); - } - BinaryOperator op = new BinaryOperator((Expression) uns.get(0), "&&", (Expression) uns.get(1)); - for (int i = 2; i < uns.size(); ++i) { - op = new BinaryOperator(op, "&&", (Expression) uns.get(i)); - } - return op; - } - - @Override - public Node visitUnaryRelExpression(CminusParser.UnaryRelExpressionContext ctx) { - Expression e = (Expression) visitRelExpression(ctx.relExpression()); - for (TerminalNode n : ctx.BANG()) { - e = new UnaryOperator("!", e); - } - return e; - } - - @Override - public Node visitRelExpression(CminusParser.RelExpressionContext ctx) { - List<Node> uns = new ArrayList<>(); - for (CminusParser.SumExpressionContext un : ctx.sumExpression()) { - uns.add(visitSumExpression(un)); - } - if (uns.size() == 1) { - return uns.get(0); - } - BinaryOperator op = new BinaryOperator((Expression) uns.get(0), ctx.relop(0).getText(), - (Expression) uns.get(1)); - for (int i = 2; i < uns.size(); ++i) { - op = new BinaryOperator(op, ctx.relop(i - 1).getText(), (Expression) uns.get(i)); - } - return op; - } - - @Override - public Node visitSumExpression(CminusParser.SumExpressionContext ctx) { - List<Node> es = new ArrayList<>(); - for (CminusParser.TermExpressionContext e : ctx.termExpression()) { - es.add(visitTermExpression(e)); - } - if (es.size() == 1) { - return es.get(0); - } - BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.sumop(0).getText(), (Expression) es.get(1)); - for (int i = 2; i < es.size(); ++i) { - op = new BinaryOperator(op, ctx.sumop(i - 1).getText(), (Expression) es.get(i)); - } - return op; - } - - @Override - public Node visitTermExpression(CminusParser.TermExpressionContext ctx) { - List<Node> es = new ArrayList<>(); - for (CminusParser.UnaryExpressionContext e : ctx.unaryExpression()) { - es.add(visitUnaryExpression(e)); - } - if (es.size() == 1) { - return es.get(0); - } - BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.mulop(0).getText(), (Expression) es.get(1)); - for (int i = 2; i < es.size(); ++i) { - op = new BinaryOperator(op, ctx.mulop(i - 1).getText(), (Expression) es.get(i)); - } - return op; - } - - @Override - public Node visitUnaryExpression(CminusParser.UnaryExpressionContext ctx) { - Node ret = visitFactor(ctx.factor()); - for (int i = ctx.unaryop().size() - 1; i >= 0; i--) { - ret = new UnaryOperator(ctx.unaryop(i).getText(), (Expression) ret); - } - return ret; - } - - @Override - public Node visitMutable(CminusParser.MutableContext ctx) { - Expression e = null; - if (ctx.expression() != null) { - e = (Expression) visitExpression(ctx.expression()); - } - String id = ctx.ID().getText(); - if (symbolTable.find(id) == null) { - LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id); - } - return new Mutable(id, e); - } - - @Override - public Node visitImmutable(CminusParser.ImmutableContext ctx) { - if (ctx.expression() != null) { - return new ParenExpression((Expression) visitExpression(ctx.expression())); - } - return visitChildren(ctx); - } - - @Override - public Node visitCall(CminusParser.CallContext ctx) { - final String id = ctx.ID().getText(); - final List<Expression> args = new ArrayList<>(); - for (CminusParser.ExpressionContext e : ctx.expression()) { - args.add((Expression) visitExpression(e)); - } - if (symbolTable.find(id) == null) { - LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id); - } - return new Call(id, args); - } - - @Override - public Node visitConstant(CminusParser.ConstantContext ctx) { - final Node node; - if (ctx.NUMCONST() != null) { - node = new NumConstant(Integer.parseInt(ctx.NUMCONST().getText())); - } else if (ctx.CHARCONST() != null) { - node = new CharConstant(ctx.CHARCONST().getText().charAt(0)); - } else if (ctx.STRINGCONST() != null) { - node = new StringConstant(ctx.STRINGCONST().getText()); - } else { - node = new BoolConstant(ctx.getText().equals("true")); - } - return node; + private final Logger LOGGER; + private SymbolTable symbolTable; + + public ASTVisitor(Logger LOGGER) { this.LOGGER = LOGGER; } + + public SymbolTable getSymbolTable() { return symbolTable; } + + private VarType getVarType(CminusParser.TypeSpecifierContext ctx) { + final String t = ctx.getText(); + return (t.equals("int")) ? VarType.INT + : (t.equals("bool")) ? VarType.BOOL + : VarType.CHAR; + } + + @Override + public Node visitProgram(CminusParser.ProgramContext ctx) { + symbolTable = new SymbolTable(); + List<Declaration> decls = new ArrayList<>(); + for (CminusParser.DeclarationContext d : ctx.declaration()) { + decls.add((Declaration)visitDeclaration(d)); + } + return new Program(decls); + } + + @Override + public Node visitVarDeclaration(CminusParser.VarDeclarationContext ctx) { + VarType type = getVarType(ctx.typeSpecifier()); + List<String> ids = new ArrayList<>(); + List<Integer> arraySizes = new ArrayList<>(); + for (CminusParser.VarDeclIdContext v : ctx.varDeclId()) { + String id = v.ID().getText(); + ids.add(id); + // symbolTable.addSymbol(id, new SymbolInfo(id, type, false)); + if (v.NUMCONST() != null) { + arraySizes.add(Integer.parseInt(v.NUMCONST().getText())); + } else { + int offset = symbolTable.addOffset(1); + + SymbolInfo symbol = new SymbolInfo(id, type, false, offset); + symbolTable.addSymbol(id, symbol); + + arraySizes.add(-1); + } + } + final boolean isStatic = false; + return new VarDeclaration(type, ids, arraySizes, isStatic); + } + + @Override + public Node visitFunDeclaration(CminusParser.FunDeclarationContext ctx) { + VarType returnType = null; + if (ctx.typeSpecifier() != null) { + returnType = getVarType(ctx.typeSpecifier()); + } + String id = ctx.ID().getText(); + symbolTable.addSymbol(id, new SymbolInfo(id, returnType, true)); + + List<Param> params = new ArrayList<>(); + + SymbolTable newTable = symbolTable.createChild(); + symbolTable = newTable; + + if (returnType != null) { + symbolTable.addSymbol("return", + new SymbolInfo("return", returnType, false, + symbolTable.addOffset(1))); + } + + for (CminusParser.ParamContext p : ctx.param()) { + params.add((Param)visitParam(p)); + } + + CompoundStatement statement = + (CompoundStatement)visitStatement(ctx.statement()); + + statement.getSymbolTable().addOtherTableBefore(newTable); + symbolTable = newTable.getParent(); + + return new FunDeclaration(returnType, id, params, statement); + } + + @Override + public Node visitParam(CminusParser.ParamContext ctx) { + VarType type = getVarType(ctx.typeSpecifier()); + String id = ctx.paramId().ID().getText(); + symbolTable.addSymbol( + id, new SymbolInfo(id, type, false, symbolTable.addOffset(1))); + + return new Param(type, id, ctx.paramId().children.size() > 1); + } + + @Override + public Node visitCompoundStmt(CminusParser.CompoundStmtContext ctx) { + SymbolTable child = symbolTable.createChild(); + symbolTable = child; + + List<Statement> statements = new ArrayList<>(); + for (CminusParser.VarDeclarationContext d : ctx.varDeclaration()) { + statements.add((VarDeclaration)visitVarDeclaration(d)); + } + for (CminusParser.StatementContext d : ctx.statement()) { + statements.add((Statement)visitStatement(d)); + } + symbolTable = child.getParent(); + + return new CompoundStatement(statements, child); + } + + @Override + public Node visitExpressionStmt(CminusParser.ExpressionStmtContext ctx) { + if (ctx.expression() == null) { + return Statement.empty(); + } + return new ExpressionStatement( + (Expression)visitExpression(ctx.expression())); + } + + @Override + public Node visitIfStmt(CminusParser.IfStmtContext ctx) { + Expression expression = + (Expression)visitSimpleExpression(ctx.simpleExpression()); + Statement trueStatement = (Statement)visitStatement(ctx.statement(0)); + Statement falseStatement = null; + if (ctx.statement().size() > 1) { + falseStatement = (Statement)visitStatement(ctx.statement(1)); + } + return new If(expression, trueStatement, falseStatement); + } + + @Override + public Node visitWhileStmt(CminusParser.WhileStmtContext ctx) { + Expression expression = + (Expression)visitSimpleExpression(ctx.simpleExpression()); + Statement statement = (Statement)visitStatement(ctx.statement()); + return new While(expression, statement); + } + + @Override + public Node visitReturnStmt(CminusParser.ReturnStmtContext ctx) { + if (ctx.expression() != null) { + return new Return((Expression)visitExpression(ctx.expression())); + } + return new Return(null); + } + + @Override + public Node visitBreakStmt(CminusParser.BreakStmtContext ctx) { + return new Break(); + } + + @Override + public Node visitExpression(CminusParser.ExpressionContext ctx) { + final Node ret; + CminusParser.MutableContext mutable = ctx.mutable(); + CminusParser.ExpressionContext expression = ctx.expression(); + if (mutable != null) { + // Assignment + ParseTree operator = ctx.getChild(1); + Mutable lhs = (Mutable)visitMutable( + mutable); // new Mutable(mutable.ID().getText(), (Expression) + // visitExpression(mutable.expression())); + Expression rhs = null; + if (expression != null) { + rhs = (Expression)visitExpression(expression); + } + ret = new Assignment(lhs, operator.getText(), rhs); + } else { + ret = visitSimpleExpression(ctx.simpleExpression()); + } + return ret; + } + + @Override + public Node visitOrExpression(CminusParser.OrExpressionContext ctx) { + List<Node> ands = new ArrayList<>(); + for (CminusParser.AndExpressionContext and : ctx.andExpression()) { + ands.add(visitAndExpression(and)); + } + if (ands.size() == 1) { + return ands.get(0); + } + BinaryOperator op = new BinaryOperator((Expression)ands.get(0), "||", + (Expression)ands.get(1)); + for (int i = 2; i < ands.size(); ++i) { + op = new BinaryOperator(op, "||", (Expression)ands.get(i)); + } + return op; + } + + @Override + public Node visitAndExpression(CminusParser.AndExpressionContext ctx) { + List<Node> uns = new ArrayList<>(); + for (CminusParser.UnaryRelExpressionContext un : ctx.unaryRelExpression()) { + uns.add(visitUnaryRelExpression(un)); + } + if (uns.size() == 1) { + return uns.get(0); + } + BinaryOperator op = new BinaryOperator((Expression)uns.get(0), "&&", + (Expression)uns.get(1)); + for (int i = 2; i < uns.size(); ++i) { + op = new BinaryOperator(op, "&&", (Expression)uns.get(i)); + } + return op; + } + + @Override + public Node + visitUnaryRelExpression(CminusParser.UnaryRelExpressionContext ctx) { + Expression e = (Expression)visitRelExpression(ctx.relExpression()); + for (TerminalNode n : ctx.BANG()) { + e = new UnaryOperator("!", e); + } + return e; + } + + @Override + public Node visitRelExpression(CminusParser.RelExpressionContext ctx) { + List<Node> uns = new ArrayList<>(); + for (CminusParser.SumExpressionContext un : ctx.sumExpression()) { + uns.add(visitSumExpression(un)); + } + if (uns.size() == 1) { + return uns.get(0); + } + BinaryOperator op = new BinaryOperator( + (Expression)uns.get(0), ctx.relop(0).getText(), (Expression)uns.get(1)); + for (int i = 2; i < uns.size(); ++i) { + op = new BinaryOperator(op, ctx.relop(i - 1).getText(), + (Expression)uns.get(i)); + } + return op; + } + + @Override + public Node visitSumExpression(CminusParser.SumExpressionContext ctx) { + List<Node> es = new ArrayList<>(); + for (CminusParser.TermExpressionContext e : ctx.termExpression()) { + es.add(visitTermExpression(e)); + } + if (es.size() == 1) { + return es.get(0); + } + BinaryOperator op = new BinaryOperator( + (Expression)es.get(0), ctx.sumop(0).getText(), (Expression)es.get(1)); + for (int i = 2; i < es.size(); ++i) { + op = new BinaryOperator(op, ctx.sumop(i - 1).getText(), + (Expression)es.get(i)); + } + return op; + } + + @Override + public Node visitTermExpression(CminusParser.TermExpressionContext ctx) { + List<Node> es = new ArrayList<>(); + for (CminusParser.UnaryExpressionContext e : ctx.unaryExpression()) { + es.add(visitUnaryExpression(e)); + } + if (es.size() == 1) { + return es.get(0); + } + BinaryOperator op = new BinaryOperator( + (Expression)es.get(0), ctx.mulop(0).getText(), (Expression)es.get(1)); + for (int i = 2; i < es.size(); ++i) { + op = new BinaryOperator(op, ctx.mulop(i - 1).getText(), + (Expression)es.get(i)); + } + return op; + } + + @Override + public Node visitUnaryExpression(CminusParser.UnaryExpressionContext ctx) { + Node ret = visitFactor(ctx.factor()); + for (int i = ctx.unaryop().size() - 1; i >= 0; i--) { + ret = new UnaryOperator(ctx.unaryop(i).getText(), (Expression)ret); + } + return ret; + } + + @Override + public Node visitMutable(CminusParser.MutableContext ctx) { + Expression e = null; + if (ctx.expression() != null) { + e = (Expression)visitExpression(ctx.expression()); + } + String id = ctx.ID().getText(); + if (symbolTable.find(id) == null) { + LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + + ": " + id); + } + return new Mutable(id, e); + } + + @Override + public Node visitImmutable(CminusParser.ImmutableContext ctx) { + if (ctx.expression() != null) { + return new ParenExpression((Expression)visitExpression(ctx.expression())); + } + return visitChildren(ctx); + } + + @Override + public Node visitCall(CminusParser.CallContext ctx) { + final String id = ctx.ID().getText(); + final List<Expression> args = new ArrayList<>(); + for (CminusParser.ExpressionContext e : ctx.expression()) { + args.add((Expression)visitExpression(e)); + } + if (symbolTable.find(id) == null) { + LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + + ": " + id); + } + return new Call(id, args); + } + + @Override + public Node visitConstant(CminusParser.ConstantContext ctx) { + final Node node; + if (ctx.NUMCONST() != null) { + node = new NumConstant(Integer.parseInt(ctx.NUMCONST().getText())); + } else if (ctx.CHARCONST() != null) { + node = new CharConstant(ctx.CHARCONST().getText().charAt(0)); + } else if (ctx.STRINGCONST() != null) { + node = new StringConstant(ctx.STRINGCONST().getText()); + } else { + node = new BoolConstant(ctx.getText().equals("true")); } + return node; + } } |