From 66f57c69b40656ba36e96285ac57217b2606cfa3 Mon Sep 17 00:00:00 2001 From: Manuel Rigger Date: Mon, 30 Aug 2021 09:54:44 +0200 Subject: [PATCH] Showcase how the reduction could look like --- pom.xml | 2 +- src/sqlancer/FoundBugException.java | 27 ++++ src/sqlancer/Main.java | 131 +++++++++++++++++- src/sqlancer/StateToReproduce.java | 6 +- .../gen/ddl/SQLite3IndexGenerator.java | 1 + .../sqlite3/oracle/SQLite3NoRECOracle.java | 55 ++++++-- 6 files changed, 204 insertions(+), 18 deletions(-) create mode 100644 src/sqlancer/FoundBugException.java diff --git a/pom.xml b/pom.xml index 1848ef107..ffa939b43 100644 --- a/pom.xml +++ b/pom.xml @@ -228,7 +228,7 @@ org.xerial sqlite-jdbc - 3.34.0 + 3.28.0 mysql diff --git a/src/sqlancer/FoundBugException.java b/src/sqlancer/FoundBugException.java new file mode 100644 index 000000000..458135b04 --- /dev/null +++ b/src/sqlancer/FoundBugException.java @@ -0,0 +1,27 @@ +package sqlancer; + +import sqlancer.sqlite3.SQLite3GlobalState; + +public class FoundBugException extends RuntimeException { + + public interface Reproducer { + public abstract boolean bugStillTriggers(SQLite3GlobalState globalState); + + public default void outputHook(SQLite3GlobalState globalState) { + + } + } + + private static final long serialVersionUID = 1L; + private Reproducer reproducer; + + public FoundBugException(String string, Reproducer reproducer) { + super(string); + this.reproducer = reproducer; + } + + public Reproducer getReproducer() { + return reproducer; + } + +} diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java index 4c1130f28..67cc08c7d 100644 --- a/src/sqlancer/Main.java +++ b/src/sqlancer/Main.java @@ -17,16 +17,20 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import java.util.stream.Collectors; import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander.Builder; +import sqlancer.FoundBugException.Reproducer; import sqlancer.arangodb.ArangoDBProvider; import sqlancer.citus.CitusProvider; import sqlancer.clickhouse.ClickHouseProvider; import sqlancer.cockroachdb.CockroachDBProvider; import sqlancer.common.log.Loggable; import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; import sqlancer.cosmos.CosmosProvider; import sqlancer.duckdb.DuckDBProvider; @@ -35,6 +39,7 @@ import sqlancer.mongodb.MongoDBProvider; import sqlancer.mysql.MySQLProvider; import sqlancer.postgres.PostgresProvider; +import sqlancer.sqlite3.SQLite3GlobalState; import sqlancer.sqlite3.SQLite3Provider; import sqlancer.tidb.TiDBProvider; @@ -61,6 +66,8 @@ private Main() { public static final class StateLogger { private final File loggerFile; + private final File reducedFile; + private File curFile; private FileWriter logFileWriter; public FileWriter currentFileWriter; @@ -94,6 +101,7 @@ public StateLogger(String databaseName, DatabaseProvider provider, Main } ensureExistsAndIsEmpty(dir, provider); loggerFile = new File(dir, databaseName + ".log"); + reducedFile = new File(dir, databaseName + "-reduced.log"); logEachSelect = options.logEachSelect(); if (logEachSelect) { curFile = new File(dir, databaseName + "-cur.log"); @@ -135,6 +143,17 @@ private FileWriter getLogFileWriter() { return logFileWriter; } + private FileWriter getReducedWriter() { + // if (logFileWriter == null) { + try { + logFileWriter = new FileWriter(reducedFile); + } catch (IOException e) { + throw new AssertionError(e); + } + // } + return logFileWriter; + } + public FileWriter getCurrentFileWriter() { if (!logEachSelect) { throw new UnsupportedOperationException(); @@ -183,6 +202,17 @@ private void write(Loggable loggable) { } } + public void logReduced(StateToReproduce state) { + FileWriter logFileWriter = getReducedWriter(); + printState(logFileWriter, state); + try { + logFileWriter.close(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + public void logException(Throwable reduce, StateToReproduce state) { Loggable stackTrace = getStackTrace(reduce); FileWriter logFileWriter2 = getLogFileWriter(); @@ -300,6 +330,8 @@ public void testConnection() throws Exception { } } + boolean observedChange; + public void run() throws Exception { G state = createGlobalState(); stateToRepro = provider.getStateToReproduce(databaseName); @@ -323,14 +355,103 @@ public void run() throws Exception { if (options.logEachSelect()) { logger.writeCurrent(state.getState()); } - provider.generateAndTestDatabase(state); try { - logger.getCurrentFileWriter().close(); - logger.currentFileWriter = null; - } catch (IOException e) { - throw new AssertionError(e); + provider.generateAndTestDatabase(state); + + } catch (FoundBugException e) { + try { + logger.getCurrentFileWriter().close(); + logger.currentFileWriter = null; + } catch (IOException e2) { + throw new AssertionError(e2); + } + Reproducer reproducer = e.getReproducer(); + + G newGlobalState = createGlobalState(); + QueryManager newManager = new QueryManager<>(newGlobalState); + newGlobalState.setDatabaseName(databaseName); + newGlobalState.setMainOptions(options); + newGlobalState.setDmbsSpecificOptions(command); + newGlobalState.setStateLogger(new StateLogger(databaseName, provider, options)); + newGlobalState.setManager(newManager); + newGlobalState.setState(stateToRepro); + + List> knownToReproduceBugStatements = new ArrayList>(); + for (Query stat : state.getState().getStatements()) { + knownToReproduceBugStatements.add((Query) stat); + } + // iterate until fixpoint + do { + observedChange = false; + knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState, + knownToReproduceBugStatements, (candidateStatements, i) -> { + candidateStatements.remove((int) i); + return true; + }); + } while (observedChange); + + for (String s : new String[] { "OR IGNORE", "OR ABORT", "OR ROLLBACK", "OR FAIL", "TEMP", + "TEMPORARY", "UNIQUE", "NOT NULL", "COLLATE BINARY", "COLLATE NOCASE", "COLLATE RTRIM", + "INT", "REAL", "TEXT", "IF NOT EXISTS", "UNINDEXED" }) { + knownToReproduceBugStatements = tryReplaceToken(state, reproducer, newGlobalState, + knownToReproduceBugStatements, " " + s, ""); + } + throw e; + } + } + } + + private List> tryReplaceToken(G state, Reproducer reproducer, G newGlobalState, + List> knownToReproduceBugStatements, String target, String replaceBy) throws Exception { + do { + observedChange = false; + knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState, + knownToReproduceBugStatements, (candidateStatements, i) -> { + Query statement = candidateStatements.get(i); + if (statement.getQueryString().contains(target)) { + candidateStatements.set(i, (Query) new SQLQueryAdapter( + statement.getQueryString().replace(target, replaceBy), true)); + return true; + } + return false; + } + + ); + } while (observedChange); + return knownToReproduceBugStatements; + } + + private List> tryReduction(G state, Reproducer reproducer, G newGlobalState, + List> knownToReproduceBugStatements, + BiFunction>, Integer, Boolean> reductionOperation) throws Exception { + for (int i = 0; i < knownToReproduceBugStatements.size(); i++) { + try (C con2 = provider.createDatabase(newGlobalState)) { + newGlobalState.setConnection(con2); + List> candidateStatements = new ArrayList<>(knownToReproduceBugStatements); + if (!reductionOperation.apply(candidateStatements, i)) { + continue; + } + newGlobalState.getState().setStatements(candidateStatements.stream().collect(Collectors.toList())); + for (Query s : candidateStatements) { + try { + s.execute(newGlobalState); + } catch (Throwable ignoredException) { + // ignore + } + } + try { + if (reproducer.bugStillTriggers((SQLite3GlobalState) newGlobalState)) { + observedChange = true; + knownToReproduceBugStatements = candidateStatements; + reproducer.outputHook((SQLite3GlobalState) newGlobalState); + state.getLogger().logReduced(newGlobalState.getState()); + } + } catch (Throwable ignoredException) { + + } } } + return knownToReproduceBugStatements; } private G getInitializedGlobalState(long seed) { diff --git a/src/sqlancer/StateToReproduce.java b/src/sqlancer/StateToReproduce.java index f6408ce58..7ac3e9ccf 100644 --- a/src/sqlancer/StateToReproduce.java +++ b/src/sqlancer/StateToReproduce.java @@ -9,7 +9,7 @@ public class StateToReproduce { - private final List> statements = new ArrayList<>(); + private List> statements = new ArrayList<>(); private final String databaseName; @@ -128,4 +128,8 @@ public OracleRunReproductionState createLocalState() { return new OracleRunReproductionState(); } + public void setStatements(List> statements) { + this.statements = statements; + } + } diff --git a/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java b/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java index 0a33016fd..ec396b979 100644 --- a/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java +++ b/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java @@ -44,6 +44,7 @@ private SQLQueryAdapter create() throws SQLException { errors.add("non-deterministic use of date() in an index"); errors.add("non-deterministic use of datetime() in an index"); errors.add("The database file is locked"); + errors.add("Abort due to constraint violation"); SQLite3Errors.addExpectedExpressionErrors(errors); if (!SQLite3Provider.mustKnowResult) { // can only happen when PRAGMA case_sensitive_like=ON; diff --git a/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java b/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java index db9f55531..24337a14d 100644 --- a/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java +++ b/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java @@ -4,7 +4,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.function.Function; +import sqlancer.FoundBugException; +import sqlancer.FoundBugException.Reproducer; import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.oracle.NoRECBase; @@ -62,19 +65,35 @@ public void check() throws SQLException { select.setFromTables(tableRefs); select.setJoinClauses(joinStatements); - int optimizedCount = getOptimizedQuery(select, randomWhereCondition); - int unoptimizedCount = getUnoptimizedQuery(select, randomWhereCondition); + Function optimizedQuery = getOptimizedQuery(select, randomWhereCondition); + Function unoptimizedQuery = getUnoptimizedQuery(select, randomWhereCondition); + int optimizedCount = optimizedQuery.apply(state); + int unoptimizedCount = unoptimizedQuery.apply(state); if (optimizedCount == NO_VALID_RESULT || unoptimizedCount == NO_VALID_RESULT) { throw new IgnoreMeException(); } if (optimizedCount != unoptimizedCount) { state.getState().getLocalState().log(optimizedQueryString + ";\n" + unoptimizedQueryString + ";"); - throw new AssertionError(optimizedCount + " " + unoptimizedCount); + throw new FoundBugException((optimizedCount + " " + unoptimizedCount), new Reproducer() { + + @Override + public boolean bugStillTriggers(SQLite3GlobalState globalState) { + return optimizedQuery.apply(globalState) != unoptimizedQuery.apply(globalState); + } + + @Override + public void outputHook(SQLite3GlobalState globalState) { + globalState.getState().logStatement(new SQLQueryAdapter(optimizedQueryString)); + globalState.getState().logStatement(new SQLQueryAdapter(unoptimizedQueryString)); + } + + }); } } - private int getUnoptimizedQuery(SQLite3Select select, SQLite3Expression randomWhereCondition) throws SQLException { + private Function getUnoptimizedQuery(SQLite3Select select, + SQLite3Expression randomWhereCondition) throws SQLException { SQLite3PostfixUnaryOperation isTrue = new SQLite3PostfixUnaryOperation(PostfixUnaryOperator.IS_TRUE, randomWhereCondition); SQLite3PostfixText asText = new SQLite3PostfixText(isTrue, " as count", null); @@ -85,10 +104,17 @@ private int getUnoptimizedQuery(SQLite3Select select, SQLite3Expression randomWh logger.writeCurrent(unoptimizedQueryString); } SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - return extractCounts(q); + return new Function() { + + @Override + public Integer apply(SQLite3GlobalState state) { + return extractCounts(q, state); + } + }; } - private int getOptimizedQuery(SQLite3Select select, SQLite3Expression randomWhereCondition) throws SQLException { + private Function getOptimizedQuery(SQLite3Select select, + SQLite3Expression randomWhereCondition) throws SQLException { boolean useAggregate = Randomly.getBoolean(); if (Randomly.getBoolean()) { select.setOrderByExpressions(gen.generateOrderBys()); @@ -106,12 +132,19 @@ private int getOptimizedQuery(SQLite3Select select, SQLite3Expression randomWher logger.writeCurrent(optimizedQueryString); } SQLQueryAdapter q = new SQLQueryAdapter(optimizedQueryString, errors); - return useAggregate ? extractCounts(q) : countRows(q); + return new Function() { + + @Override + public Integer apply(SQLite3GlobalState state) { + return useAggregate ? extractCounts(q, state) : countRows(q, state); + } + + }; } - private int countRows(SQLQueryAdapter q) { + private int countRows(SQLQueryAdapter q, SQLite3GlobalState globalState) { int count = 0; - try (SQLancerResultSet rs = q.executeAndGet(state)) { + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { if (rs == null) { return NO_VALID_RESULT; } else { @@ -132,9 +165,9 @@ private int countRows(SQLQueryAdapter q) { return count; } - private int extractCounts(SQLQueryAdapter q) { + private int extractCounts(SQLQueryAdapter q, SQLite3GlobalState globalState) { int count = 0; - try (SQLancerResultSet rs = q.executeAndGet(state)) { + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { if (rs == null) { return NO_VALID_RESULT; } else {