summaryrefslogtreecommitdiff
path: root/submit/ASTVisitor.java
diff options
context:
space:
mode:
Diffstat (limited to 'submit/ASTVisitor.java')
-rw-r--r--submit/ASTVisitor.java304
1 files changed, 304 insertions, 0 deletions
diff --git a/submit/ASTVisitor.java b/submit/ASTVisitor.java
new file mode 100644
index 0000000..b8762b5
--- /dev/null
+++ b/submit/ASTVisitor.java
@@ -0,0 +1,304 @@
+package submit;
+
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import parser.CminusBaseVisitor;
+import parser.CminusParser;
+import submit.ast.*;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Logger;
+
+public class ASTVisitor extends CminusBaseVisitor<Node> {
+ private final Logger LOGGER;
+ private SymbolTable symbolTable;
+
+ public ASTVisitor(Logger LOGGER) {
+ this.LOGGER = LOGGER;
+ }
+
+ public SymbolTable getSymbolTable() {
+ return symbolTable;
+ }
+
+ private VarType getVarType(CminusParser.TypeSpecifierContext ctx) {
+ final String t = ctx.getText();
+ return (t.equals("int")) ? VarType.INT : (t.equals("bool")) ? VarType.BOOL : VarType.CHAR;
+ }
+
+ @Override
+ public Node visitProgram(CminusParser.ProgramContext ctx) {
+ symbolTable = new SymbolTable();
+ List<Declaration> decls = new ArrayList<>();
+ for (CminusParser.DeclarationContext d : ctx.declaration()) {
+ decls.add((Declaration) visitDeclaration(d));
+ }
+ return new Program(decls);
+ }
+
+ @Override
+ public Node visitVarDeclaration(CminusParser.VarDeclarationContext ctx) {
+ VarType type = getVarType(ctx.typeSpecifier());
+ List<String> ids = new ArrayList<>();
+ List<Integer> arraySizes = new ArrayList<>();
+ for (CminusParser.VarDeclIdContext v : ctx.varDeclId()) {
+ String id = v.ID().getText();
+ ids.add(id);
+ symbolTable.addSymbol(id, new SymbolInfo(id, type, false));
+ if (v.NUMCONST() != null) {
+ arraySizes.add(Integer.parseInt(v.NUMCONST().getText()));
+ } else {
+ arraySizes.add(-1);
+ }
+ }
+ final boolean isStatic = false;
+ return new VarDeclaration(type, ids, arraySizes, isStatic);
+ }
+
+ @Override
+ public Node visitFunDeclaration(CminusParser.FunDeclarationContext ctx) {
+ VarType returnType = null;
+ if (ctx.typeSpecifier() != null) {
+ returnType = getVarType(ctx.typeSpecifier());
+ }
+ String id = ctx.ID().getText();
+ List<Param> params = new ArrayList<>();
+ for (CminusParser.ParamContext p : ctx.param()) {
+ params.add((Param) visitParam(p));
+ }
+ Statement statement = (Statement) visitStatement(ctx.statement());
+ symbolTable.addSymbol(id, new SymbolInfo(id, returnType, true));
+ return new FunDeclaration(returnType, id, params, statement);
+ }
+
+ @Override
+ public Node visitParam(CminusParser.ParamContext ctx) {
+ VarType type = getVarType(ctx.typeSpecifier());
+ String id = ctx.paramId().ID().getText();
+ symbolTable.addSymbol(id, new SymbolInfo(id, type, false));
+ return new Param(type, id, ctx.paramId().children.size() > 1);
+ }
+
+ @Override
+ public Node visitCompoundStmt(CminusParser.CompoundStmtContext ctx) {
+ symbolTable = symbolTable.createChild();
+ List<Statement> statements = new ArrayList<>();
+ for (CminusParser.VarDeclarationContext d : ctx.varDeclaration()) {
+ statements.add((VarDeclaration) visitVarDeclaration(d));
+ }
+ for (CminusParser.StatementContext d : ctx.statement()) {
+ statements.add((Statement) visitStatement(d));
+ }
+ symbolTable = symbolTable.getParent();
+ return new CompoundStatement(statements);
+ }
+
+ @Override
+ public Node visitExpressionStmt(CminusParser.ExpressionStmtContext ctx) {
+ if (ctx.expression() == null) {
+ return Statement.empty();
+ }
+ return new ExpressionStatement((Expression) visitExpression(ctx.expression()));
+ }
+
+ @Override
+ public Node visitIfStmt(CminusParser.IfStmtContext ctx) {
+ Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression());
+ Statement trueStatement = (Statement) visitStatement(ctx.statement(0));
+ Statement falseStatement = null;
+ if (ctx.statement().size() > 1) {
+ falseStatement = (Statement) visitStatement(ctx.statement(1));
+ }
+ return new If(expression, trueStatement, falseStatement);
+ }
+
+ @Override
+ public Node visitWhileStmt(CminusParser.WhileStmtContext ctx) {
+ Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression());
+ Statement statement = (Statement) visitStatement(ctx.statement());
+ return new While(expression, statement);
+ }
+
+ @Override
+ public Node visitReturnStmt(CminusParser.ReturnStmtContext ctx) {
+ if (ctx.expression() != null) {
+ return new Return((Expression) visitExpression(ctx.expression()));
+ }
+ return new Return(null);
+ }
+
+ @Override
+ public Node visitBreakStmt(CminusParser.BreakStmtContext ctx) {
+ return new Break();
+ }
+
+ @Override
+ public Node visitExpression(CminusParser.ExpressionContext ctx) {
+ final Node ret;
+ CminusParser.MutableContext mutable = ctx.mutable();
+ CminusParser.ExpressionContext expression = ctx.expression();
+ if (mutable != null) {
+ // Assignment
+ ParseTree operator = ctx.getChild(1);
+ Mutable lhs = (Mutable) visitMutable(mutable);// new Mutable(mutable.ID().getText(), (Expression)
+ // visitExpression(mutable.expression()));
+ Expression rhs = null;
+ if (expression != null) {
+ rhs = (Expression) visitExpression(expression);
+ }
+ ret = new Assignment(lhs, operator.getText(), rhs);
+ } else {
+ ret = visitSimpleExpression(ctx.simpleExpression());
+ }
+ return ret;
+ }
+
+ @Override
+ public Node visitOrExpression(CminusParser.OrExpressionContext ctx) {
+ List<Node> ands = new ArrayList<>();
+ for (CminusParser.AndExpressionContext and : ctx.andExpression()) {
+ ands.add(visitAndExpression(and));
+ }
+ if (ands.size() == 1) {
+ return ands.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression) ands.get(0), "||", (Expression) ands.get(1));
+ for (int i = 2; i < ands.size(); ++i) {
+ op = new BinaryOperator(op, "||", (Expression) ands.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitAndExpression(CminusParser.AndExpressionContext ctx) {
+ List<Node> uns = new ArrayList<>();
+ for (CminusParser.UnaryRelExpressionContext un : ctx.unaryRelExpression()) {
+ uns.add(visitUnaryRelExpression(un));
+ }
+ if (uns.size() == 1) {
+ return uns.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression) uns.get(0), "&&", (Expression) uns.get(1));
+ for (int i = 2; i < uns.size(); ++i) {
+ op = new BinaryOperator(op, "&&", (Expression) uns.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitUnaryRelExpression(CminusParser.UnaryRelExpressionContext ctx) {
+ Expression e = (Expression) visitRelExpression(ctx.relExpression());
+ for (TerminalNode n : ctx.BANG()) {
+ e = new UnaryOperator("!", e);
+ }
+ return e;
+ }
+
+ @Override
+ public Node visitRelExpression(CminusParser.RelExpressionContext ctx) {
+ List<Node> uns = new ArrayList<>();
+ for (CminusParser.SumExpressionContext un : ctx.sumExpression()) {
+ uns.add(visitSumExpression(un));
+ }
+ if (uns.size() == 1) {
+ return uns.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression) uns.get(0), ctx.relop(0).getText(),
+ (Expression) uns.get(1));
+ for (int i = 2; i < uns.size(); ++i) {
+ op = new BinaryOperator(op, ctx.relop(i - 1).getText(), (Expression) uns.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitSumExpression(CminusParser.SumExpressionContext ctx) {
+ List<Node> es = new ArrayList<>();
+ for (CminusParser.TermExpressionContext e : ctx.termExpression()) {
+ es.add(visitTermExpression(e));
+ }
+ if (es.size() == 1) {
+ return es.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.sumop(0).getText(), (Expression) es.get(1));
+ for (int i = 2; i < es.size(); ++i) {
+ op = new BinaryOperator(op, ctx.sumop(i - 1).getText(), (Expression) es.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitTermExpression(CminusParser.TermExpressionContext ctx) {
+ List<Node> es = new ArrayList<>();
+ for (CminusParser.UnaryExpressionContext e : ctx.unaryExpression()) {
+ es.add(visitUnaryExpression(e));
+ }
+ if (es.size() == 1) {
+ return es.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.mulop(0).getText(), (Expression) es.get(1));
+ for (int i = 2; i < es.size(); ++i) {
+ op = new BinaryOperator(op, ctx.mulop(i - 1).getText(), (Expression) es.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitUnaryExpression(CminusParser.UnaryExpressionContext ctx) {
+ Node ret = visitFactor(ctx.factor());
+ for (int i = ctx.unaryop().size() - 1; i >= 0; i--) {
+ ret = new UnaryOperator(ctx.unaryop(i).getText(), (Expression) ret);
+ }
+ return ret;
+ }
+
+ @Override
+ public Node visitMutable(CminusParser.MutableContext ctx) {
+ Expression e = null;
+ if (ctx.expression() != null) {
+ e = (Expression) visitExpression(ctx.expression());
+ }
+ String id = ctx.ID().getText();
+ if (symbolTable.find(id) == null) {
+ LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id);
+ }
+ return new Mutable(id, e);
+ }
+
+ @Override
+ public Node visitImmutable(CminusParser.ImmutableContext ctx) {
+ if (ctx.expression() != null) {
+ return new ParenExpression((Expression) visitExpression(ctx.expression()));
+ }
+ return visitChildren(ctx);
+ }
+
+ @Override
+ public Node visitCall(CminusParser.CallContext ctx) {
+ final String id = ctx.ID().getText();
+ final List<Expression> args = new ArrayList<>();
+ for (CminusParser.ExpressionContext e : ctx.expression()) {
+ args.add((Expression) visitExpression(e));
+ }
+ if (symbolTable.find(id) == null) {
+ LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id);
+ }
+ return new Call(id, args);
+ }
+
+ @Override
+ public Node visitConstant(CminusParser.ConstantContext ctx) {
+ final Node node;
+ if (ctx.NUMCONST() != null) {
+ node = new NumConstant(Integer.parseInt(ctx.NUMCONST().getText()));
+ } else if (ctx.CHARCONST() != null) {
+ node = new CharConstant(ctx.CHARCONST().getText().charAt(0));
+ } else if (ctx.STRINGCONST() != null) {
+ node = new StringConstant(ctx.STRINGCONST().getText());
+ } else {
+ node = new BoolConstant(ctx.getText().equals("true"));
+ }
+ return node;
+ }
+}