summaryrefslogtreecommitdiff
path: root/submit/ASTVisitor.java
diff options
context:
space:
mode:
Diffstat (limited to 'submit/ASTVisitor.java')
-rw-r--r--submit/ASTVisitor.java624
1 files changed, 331 insertions, 293 deletions
diff --git a/submit/ASTVisitor.java b/submit/ASTVisitor.java
index b8762b5..5cafbb1 100644
--- a/submit/ASTVisitor.java
+++ b/submit/ASTVisitor.java
@@ -1,304 +1,342 @@
package submit;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Logger;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;
import parser.CminusBaseVisitor;
import parser.CminusParser;
import submit.ast.*;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.logging.Logger;
-
public class ASTVisitor extends CminusBaseVisitor<Node> {
- private final Logger LOGGER;
- private SymbolTable symbolTable;
-
- public ASTVisitor(Logger LOGGER) {
- this.LOGGER = LOGGER;
- }
-
- public SymbolTable getSymbolTable() {
- return symbolTable;
- }
-
- private VarType getVarType(CminusParser.TypeSpecifierContext ctx) {
- final String t = ctx.getText();
- return (t.equals("int")) ? VarType.INT : (t.equals("bool")) ? VarType.BOOL : VarType.CHAR;
- }
-
- @Override
- public Node visitProgram(CminusParser.ProgramContext ctx) {
- symbolTable = new SymbolTable();
- List<Declaration> decls = new ArrayList<>();
- for (CminusParser.DeclarationContext d : ctx.declaration()) {
- decls.add((Declaration) visitDeclaration(d));
- }
- return new Program(decls);
- }
-
- @Override
- public Node visitVarDeclaration(CminusParser.VarDeclarationContext ctx) {
- VarType type = getVarType(ctx.typeSpecifier());
- List<String> ids = new ArrayList<>();
- List<Integer> arraySizes = new ArrayList<>();
- for (CminusParser.VarDeclIdContext v : ctx.varDeclId()) {
- String id = v.ID().getText();
- ids.add(id);
- symbolTable.addSymbol(id, new SymbolInfo(id, type, false));
- if (v.NUMCONST() != null) {
- arraySizes.add(Integer.parseInt(v.NUMCONST().getText()));
- } else {
- arraySizes.add(-1);
- }
- }
- final boolean isStatic = false;
- return new VarDeclaration(type, ids, arraySizes, isStatic);
- }
-
- @Override
- public Node visitFunDeclaration(CminusParser.FunDeclarationContext ctx) {
- VarType returnType = null;
- if (ctx.typeSpecifier() != null) {
- returnType = getVarType(ctx.typeSpecifier());
- }
- String id = ctx.ID().getText();
- List<Param> params = new ArrayList<>();
- for (CminusParser.ParamContext p : ctx.param()) {
- params.add((Param) visitParam(p));
- }
- Statement statement = (Statement) visitStatement(ctx.statement());
- symbolTable.addSymbol(id, new SymbolInfo(id, returnType, true));
- return new FunDeclaration(returnType, id, params, statement);
- }
-
- @Override
- public Node visitParam(CminusParser.ParamContext ctx) {
- VarType type = getVarType(ctx.typeSpecifier());
- String id = ctx.paramId().ID().getText();
- symbolTable.addSymbol(id, new SymbolInfo(id, type, false));
- return new Param(type, id, ctx.paramId().children.size() > 1);
- }
-
- @Override
- public Node visitCompoundStmt(CminusParser.CompoundStmtContext ctx) {
- symbolTable = symbolTable.createChild();
- List<Statement> statements = new ArrayList<>();
- for (CminusParser.VarDeclarationContext d : ctx.varDeclaration()) {
- statements.add((VarDeclaration) visitVarDeclaration(d));
- }
- for (CminusParser.StatementContext d : ctx.statement()) {
- statements.add((Statement) visitStatement(d));
- }
- symbolTable = symbolTable.getParent();
- return new CompoundStatement(statements);
- }
-
- @Override
- public Node visitExpressionStmt(CminusParser.ExpressionStmtContext ctx) {
- if (ctx.expression() == null) {
- return Statement.empty();
- }
- return new ExpressionStatement((Expression) visitExpression(ctx.expression()));
- }
-
- @Override
- public Node visitIfStmt(CminusParser.IfStmtContext ctx) {
- Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression());
- Statement trueStatement = (Statement) visitStatement(ctx.statement(0));
- Statement falseStatement = null;
- if (ctx.statement().size() > 1) {
- falseStatement = (Statement) visitStatement(ctx.statement(1));
- }
- return new If(expression, trueStatement, falseStatement);
- }
-
- @Override
- public Node visitWhileStmt(CminusParser.WhileStmtContext ctx) {
- Expression expression = (Expression) visitSimpleExpression(ctx.simpleExpression());
- Statement statement = (Statement) visitStatement(ctx.statement());
- return new While(expression, statement);
- }
-
- @Override
- public Node visitReturnStmt(CminusParser.ReturnStmtContext ctx) {
- if (ctx.expression() != null) {
- return new Return((Expression) visitExpression(ctx.expression()));
- }
- return new Return(null);
- }
-
- @Override
- public Node visitBreakStmt(CminusParser.BreakStmtContext ctx) {
- return new Break();
- }
-
- @Override
- public Node visitExpression(CminusParser.ExpressionContext ctx) {
- final Node ret;
- CminusParser.MutableContext mutable = ctx.mutable();
- CminusParser.ExpressionContext expression = ctx.expression();
- if (mutable != null) {
- // Assignment
- ParseTree operator = ctx.getChild(1);
- Mutable lhs = (Mutable) visitMutable(mutable);// new Mutable(mutable.ID().getText(), (Expression)
- // visitExpression(mutable.expression()));
- Expression rhs = null;
- if (expression != null) {
- rhs = (Expression) visitExpression(expression);
- }
- ret = new Assignment(lhs, operator.getText(), rhs);
- } else {
- ret = visitSimpleExpression(ctx.simpleExpression());
- }
- return ret;
- }
-
- @Override
- public Node visitOrExpression(CminusParser.OrExpressionContext ctx) {
- List<Node> ands = new ArrayList<>();
- for (CminusParser.AndExpressionContext and : ctx.andExpression()) {
- ands.add(visitAndExpression(and));
- }
- if (ands.size() == 1) {
- return ands.get(0);
- }
- BinaryOperator op = new BinaryOperator((Expression) ands.get(0), "||", (Expression) ands.get(1));
- for (int i = 2; i < ands.size(); ++i) {
- op = new BinaryOperator(op, "||", (Expression) ands.get(i));
- }
- return op;
- }
-
- @Override
- public Node visitAndExpression(CminusParser.AndExpressionContext ctx) {
- List<Node> uns = new ArrayList<>();
- for (CminusParser.UnaryRelExpressionContext un : ctx.unaryRelExpression()) {
- uns.add(visitUnaryRelExpression(un));
- }
- if (uns.size() == 1) {
- return uns.get(0);
- }
- BinaryOperator op = new BinaryOperator((Expression) uns.get(0), "&&", (Expression) uns.get(1));
- for (int i = 2; i < uns.size(); ++i) {
- op = new BinaryOperator(op, "&&", (Expression) uns.get(i));
- }
- return op;
- }
-
- @Override
- public Node visitUnaryRelExpression(CminusParser.UnaryRelExpressionContext ctx) {
- Expression e = (Expression) visitRelExpression(ctx.relExpression());
- for (TerminalNode n : ctx.BANG()) {
- e = new UnaryOperator("!", e);
- }
- return e;
- }
-
- @Override
- public Node visitRelExpression(CminusParser.RelExpressionContext ctx) {
- List<Node> uns = new ArrayList<>();
- for (CminusParser.SumExpressionContext un : ctx.sumExpression()) {
- uns.add(visitSumExpression(un));
- }
- if (uns.size() == 1) {
- return uns.get(0);
- }
- BinaryOperator op = new BinaryOperator((Expression) uns.get(0), ctx.relop(0).getText(),
- (Expression) uns.get(1));
- for (int i = 2; i < uns.size(); ++i) {
- op = new BinaryOperator(op, ctx.relop(i - 1).getText(), (Expression) uns.get(i));
- }
- return op;
- }
-
- @Override
- public Node visitSumExpression(CminusParser.SumExpressionContext ctx) {
- List<Node> es = new ArrayList<>();
- for (CminusParser.TermExpressionContext e : ctx.termExpression()) {
- es.add(visitTermExpression(e));
- }
- if (es.size() == 1) {
- return es.get(0);
- }
- BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.sumop(0).getText(), (Expression) es.get(1));
- for (int i = 2; i < es.size(); ++i) {
- op = new BinaryOperator(op, ctx.sumop(i - 1).getText(), (Expression) es.get(i));
- }
- return op;
- }
-
- @Override
- public Node visitTermExpression(CminusParser.TermExpressionContext ctx) {
- List<Node> es = new ArrayList<>();
- for (CminusParser.UnaryExpressionContext e : ctx.unaryExpression()) {
- es.add(visitUnaryExpression(e));
- }
- if (es.size() == 1) {
- return es.get(0);
- }
- BinaryOperator op = new BinaryOperator((Expression) es.get(0), ctx.mulop(0).getText(), (Expression) es.get(1));
- for (int i = 2; i < es.size(); ++i) {
- op = new BinaryOperator(op, ctx.mulop(i - 1).getText(), (Expression) es.get(i));
- }
- return op;
- }
-
- @Override
- public Node visitUnaryExpression(CminusParser.UnaryExpressionContext ctx) {
- Node ret = visitFactor(ctx.factor());
- for (int i = ctx.unaryop().size() - 1; i >= 0; i--) {
- ret = new UnaryOperator(ctx.unaryop(i).getText(), (Expression) ret);
- }
- return ret;
- }
-
- @Override
- public Node visitMutable(CminusParser.MutableContext ctx) {
- Expression e = null;
- if (ctx.expression() != null) {
- e = (Expression) visitExpression(ctx.expression());
- }
- String id = ctx.ID().getText();
- if (symbolTable.find(id) == null) {
- LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id);
- }
- return new Mutable(id, e);
- }
-
- @Override
- public Node visitImmutable(CminusParser.ImmutableContext ctx) {
- if (ctx.expression() != null) {
- return new ParenExpression((Expression) visitExpression(ctx.expression()));
- }
- return visitChildren(ctx);
- }
-
- @Override
- public Node visitCall(CminusParser.CallContext ctx) {
- final String id = ctx.ID().getText();
- final List<Expression> args = new ArrayList<>();
- for (CminusParser.ExpressionContext e : ctx.expression()) {
- args.add((Expression) visitExpression(e));
- }
- if (symbolTable.find(id) == null) {
- LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() + ": " + id);
- }
- return new Call(id, args);
- }
-
- @Override
- public Node visitConstant(CminusParser.ConstantContext ctx) {
- final Node node;
- if (ctx.NUMCONST() != null) {
- node = new NumConstant(Integer.parseInt(ctx.NUMCONST().getText()));
- } else if (ctx.CHARCONST() != null) {
- node = new CharConstant(ctx.CHARCONST().getText().charAt(0));
- } else if (ctx.STRINGCONST() != null) {
- node = new StringConstant(ctx.STRINGCONST().getText());
- } else {
- node = new BoolConstant(ctx.getText().equals("true"));
- }
- return node;
+ private final Logger LOGGER;
+ private SymbolTable symbolTable;
+
+ public ASTVisitor(Logger LOGGER) { this.LOGGER = LOGGER; }
+
+ public SymbolTable getSymbolTable() { return symbolTable; }
+
+ private VarType getVarType(CminusParser.TypeSpecifierContext ctx) {
+ final String t = ctx.getText();
+ return (t.equals("int")) ? VarType.INT
+ : (t.equals("bool")) ? VarType.BOOL
+ : VarType.CHAR;
+ }
+
+ @Override
+ public Node visitProgram(CminusParser.ProgramContext ctx) {
+ symbolTable = new SymbolTable();
+ List<Declaration> decls = new ArrayList<>();
+ for (CminusParser.DeclarationContext d : ctx.declaration()) {
+ decls.add((Declaration)visitDeclaration(d));
+ }
+ return new Program(decls);
+ }
+
+ @Override
+ public Node visitVarDeclaration(CminusParser.VarDeclarationContext ctx) {
+ VarType type = getVarType(ctx.typeSpecifier());
+ List<String> ids = new ArrayList<>();
+ List<Integer> arraySizes = new ArrayList<>();
+ for (CminusParser.VarDeclIdContext v : ctx.varDeclId()) {
+ String id = v.ID().getText();
+ ids.add(id);
+ // symbolTable.addSymbol(id, new SymbolInfo(id, type, false));
+ if (v.NUMCONST() != null) {
+ arraySizes.add(Integer.parseInt(v.NUMCONST().getText()));
+ } else {
+ int offset = symbolTable.addOffset(1);
+
+ SymbolInfo symbol = new SymbolInfo(id, type, false, offset);
+ symbolTable.addSymbol(id, symbol);
+
+ arraySizes.add(-1);
+ }
+ }
+ final boolean isStatic = false;
+ return new VarDeclaration(type, ids, arraySizes, isStatic);
+ }
+
+ @Override
+ public Node visitFunDeclaration(CminusParser.FunDeclarationContext ctx) {
+ VarType returnType = null;
+ if (ctx.typeSpecifier() != null) {
+ returnType = getVarType(ctx.typeSpecifier());
+ }
+ String id = ctx.ID().getText();
+ symbolTable.addSymbol(id, new SymbolInfo(id, returnType, true));
+
+ List<Param> params = new ArrayList<>();
+
+ SymbolTable newTable = symbolTable.createChild();
+ symbolTable = newTable;
+
+ if (returnType != null) {
+ symbolTable.addSymbol("return",
+ new SymbolInfo("return", returnType, false,
+ symbolTable.addOffset(1)));
+ }
+
+ for (CminusParser.ParamContext p : ctx.param()) {
+ params.add((Param)visitParam(p));
+ }
+
+ CompoundStatement statement =
+ (CompoundStatement)visitStatement(ctx.statement());
+
+ statement.getSymbolTable().addOtherTableBefore(newTable);
+ symbolTable = newTable.getParent();
+
+ return new FunDeclaration(returnType, id, params, statement);
+ }
+
+ @Override
+ public Node visitParam(CminusParser.ParamContext ctx) {
+ VarType type = getVarType(ctx.typeSpecifier());
+ String id = ctx.paramId().ID().getText();
+ symbolTable.addSymbol(
+ id, new SymbolInfo(id, type, false, symbolTable.addOffset(1)));
+
+ return new Param(type, id, ctx.paramId().children.size() > 1);
+ }
+
+ @Override
+ public Node visitCompoundStmt(CminusParser.CompoundStmtContext ctx) {
+ SymbolTable child = symbolTable.createChild();
+ symbolTable = child;
+
+ List<Statement> statements = new ArrayList<>();
+ for (CminusParser.VarDeclarationContext d : ctx.varDeclaration()) {
+ statements.add((VarDeclaration)visitVarDeclaration(d));
+ }
+ for (CminusParser.StatementContext d : ctx.statement()) {
+ statements.add((Statement)visitStatement(d));
+ }
+ symbolTable = child.getParent();
+
+ return new CompoundStatement(statements, child);
+ }
+
+ @Override
+ public Node visitExpressionStmt(CminusParser.ExpressionStmtContext ctx) {
+ if (ctx.expression() == null) {
+ return Statement.empty();
+ }
+ return new ExpressionStatement(
+ (Expression)visitExpression(ctx.expression()));
+ }
+
+ @Override
+ public Node visitIfStmt(CminusParser.IfStmtContext ctx) {
+ Expression expression =
+ (Expression)visitSimpleExpression(ctx.simpleExpression());
+ Statement trueStatement = (Statement)visitStatement(ctx.statement(0));
+ Statement falseStatement = null;
+ if (ctx.statement().size() > 1) {
+ falseStatement = (Statement)visitStatement(ctx.statement(1));
+ }
+ return new If(expression, trueStatement, falseStatement);
+ }
+
+ @Override
+ public Node visitWhileStmt(CminusParser.WhileStmtContext ctx) {
+ Expression expression =
+ (Expression)visitSimpleExpression(ctx.simpleExpression());
+ Statement statement = (Statement)visitStatement(ctx.statement());
+ return new While(expression, statement);
+ }
+
+ @Override
+ public Node visitReturnStmt(CminusParser.ReturnStmtContext ctx) {
+ if (ctx.expression() != null) {
+ return new Return((Expression)visitExpression(ctx.expression()));
+ }
+ return new Return(null);
+ }
+
+ @Override
+ public Node visitBreakStmt(CminusParser.BreakStmtContext ctx) {
+ return new Break();
+ }
+
+ @Override
+ public Node visitExpression(CminusParser.ExpressionContext ctx) {
+ final Node ret;
+ CminusParser.MutableContext mutable = ctx.mutable();
+ CminusParser.ExpressionContext expression = ctx.expression();
+ if (mutable != null) {
+ // Assignment
+ ParseTree operator = ctx.getChild(1);
+ Mutable lhs = (Mutable)visitMutable(
+ mutable); // new Mutable(mutable.ID().getText(), (Expression)
+ // visitExpression(mutable.expression()));
+ Expression rhs = null;
+ if (expression != null) {
+ rhs = (Expression)visitExpression(expression);
+ }
+ ret = new Assignment(lhs, operator.getText(), rhs);
+ } else {
+ ret = visitSimpleExpression(ctx.simpleExpression());
+ }
+ return ret;
+ }
+
+ @Override
+ public Node visitOrExpression(CminusParser.OrExpressionContext ctx) {
+ List<Node> ands = new ArrayList<>();
+ for (CminusParser.AndExpressionContext and : ctx.andExpression()) {
+ ands.add(visitAndExpression(and));
+ }
+ if (ands.size() == 1) {
+ return ands.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression)ands.get(0), "||",
+ (Expression)ands.get(1));
+ for (int i = 2; i < ands.size(); ++i) {
+ op = new BinaryOperator(op, "||", (Expression)ands.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitAndExpression(CminusParser.AndExpressionContext ctx) {
+ List<Node> uns = new ArrayList<>();
+ for (CminusParser.UnaryRelExpressionContext un : ctx.unaryRelExpression()) {
+ uns.add(visitUnaryRelExpression(un));
+ }
+ if (uns.size() == 1) {
+ return uns.get(0);
+ }
+ BinaryOperator op = new BinaryOperator((Expression)uns.get(0), "&&",
+ (Expression)uns.get(1));
+ for (int i = 2; i < uns.size(); ++i) {
+ op = new BinaryOperator(op, "&&", (Expression)uns.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node
+ visitUnaryRelExpression(CminusParser.UnaryRelExpressionContext ctx) {
+ Expression e = (Expression)visitRelExpression(ctx.relExpression());
+ for (TerminalNode n : ctx.BANG()) {
+ e = new UnaryOperator("!", e);
+ }
+ return e;
+ }
+
+ @Override
+ public Node visitRelExpression(CminusParser.RelExpressionContext ctx) {
+ List<Node> uns = new ArrayList<>();
+ for (CminusParser.SumExpressionContext un : ctx.sumExpression()) {
+ uns.add(visitSumExpression(un));
+ }
+ if (uns.size() == 1) {
+ return uns.get(0);
+ }
+ BinaryOperator op = new BinaryOperator(
+ (Expression)uns.get(0), ctx.relop(0).getText(), (Expression)uns.get(1));
+ for (int i = 2; i < uns.size(); ++i) {
+ op = new BinaryOperator(op, ctx.relop(i - 1).getText(),
+ (Expression)uns.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitSumExpression(CminusParser.SumExpressionContext ctx) {
+ List<Node> es = new ArrayList<>();
+ for (CminusParser.TermExpressionContext e : ctx.termExpression()) {
+ es.add(visitTermExpression(e));
+ }
+ if (es.size() == 1) {
+ return es.get(0);
+ }
+ BinaryOperator op = new BinaryOperator(
+ (Expression)es.get(0), ctx.sumop(0).getText(), (Expression)es.get(1));
+ for (int i = 2; i < es.size(); ++i) {
+ op = new BinaryOperator(op, ctx.sumop(i - 1).getText(),
+ (Expression)es.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitTermExpression(CminusParser.TermExpressionContext ctx) {
+ List<Node> es = new ArrayList<>();
+ for (CminusParser.UnaryExpressionContext e : ctx.unaryExpression()) {
+ es.add(visitUnaryExpression(e));
+ }
+ if (es.size() == 1) {
+ return es.get(0);
+ }
+ BinaryOperator op = new BinaryOperator(
+ (Expression)es.get(0), ctx.mulop(0).getText(), (Expression)es.get(1));
+ for (int i = 2; i < es.size(); ++i) {
+ op = new BinaryOperator(op, ctx.mulop(i - 1).getText(),
+ (Expression)es.get(i));
+ }
+ return op;
+ }
+
+ @Override
+ public Node visitUnaryExpression(CminusParser.UnaryExpressionContext ctx) {
+ Node ret = visitFactor(ctx.factor());
+ for (int i = ctx.unaryop().size() - 1; i >= 0; i--) {
+ ret = new UnaryOperator(ctx.unaryop(i).getText(), (Expression)ret);
+ }
+ return ret;
+ }
+
+ @Override
+ public Node visitMutable(CminusParser.MutableContext ctx) {
+ Expression e = null;
+ if (ctx.expression() != null) {
+ e = (Expression)visitExpression(ctx.expression());
+ }
+ String id = ctx.ID().getText();
+ if (symbolTable.find(id) == null) {
+ LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() +
+ ": " + id);
+ }
+ return new Mutable(id, e);
+ }
+
+ @Override
+ public Node visitImmutable(CminusParser.ImmutableContext ctx) {
+ if (ctx.expression() != null) {
+ return new ParenExpression((Expression)visitExpression(ctx.expression()));
+ }
+ return visitChildren(ctx);
+ }
+
+ @Override
+ public Node visitCall(CminusParser.CallContext ctx) {
+ final String id = ctx.ID().getText();
+ final List<Expression> args = new ArrayList<>();
+ for (CminusParser.ExpressionContext e : ctx.expression()) {
+ args.add((Expression)visitExpression(e));
+ }
+ if (symbolTable.find(id) == null) {
+ LOGGER.warning("Undefined symbol on line " + ctx.getStart().getLine() +
+ ": " + id);
+ }
+ return new Call(id, args);
+ }
+
+ @Override
+ public Node visitConstant(CminusParser.ConstantContext ctx) {
+ final Node node;
+ if (ctx.NUMCONST() != null) {
+ node = new NumConstant(Integer.parseInt(ctx.NUMCONST().getText()));
+ } else if (ctx.CHARCONST() != null) {
+ node = new CharConstant(ctx.CHARCONST().getText().charAt(0));
+ } else if (ctx.STRINGCONST() != null) {
+ node = new StringConstant(ctx.STRINGCONST().getText());
+ } else {
+ node = new BoolConstant(ctx.getText().equals("true"));
}
+ return node;
+ }
}