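/*
 * Web-page navigation text captured together with this file (not part of the source code):
 * Skip to content
 * Snippets Groups Projects
 * ResultSetMapper.java 15.2 KiB
 * Newer Older
 */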
package com.trimaral.orm;

import com.google.common.base.CharMatcher;
import com.google.common.collect.ObjectArrays;
import com.google.common.primitives.Primitives;
import com.google.i18n.phonenumbers.NumberParseException;
import com.trimaral.orm.annotations.Column;
import com.trimaral.orm.annotations.Table;
import com.trimaral.orm.exceptions.AnnotationNotFoundException;
import com.trimaral.orm.exceptions.DAOException;
import com.trimaral.orm.util.ClassUtil;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.*;
import java.time.YearMonth;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringJoiner;

/**
 * Reflection-based mapper between an entity class and its database table.
 * <p>
 * The entity class must carry a {@link Table} annotation, whose {@code name()} is used as the table
 * name. The mapped columns are the fields returned by
 * {@code ClassUtil.getAnnotatedDeclaredFields(clazz, Column.class)}. Fields whose {@code Column}
 * reports {@code isPrimaryKey()} form the primary key.
 * <p>
 * Every statement is executed through a {@link PreparedStatement} obtained from the supplied
 * {@link DataSource}. The {@code criteria} strings accepted by the {@code *ByCriteria} methods are
 * pasted verbatim after {@code WHERE}. Never build them from untrusted input; pass such values
 * through {@code params} instead, where they are bound to the {@code ?} placeholders.
 * <p>
 * NOTE(review): {@code INSERT IGNORE}, {@code ON DUPLICATE KEY UPDATE} and the {@code yyyy-MM-00}
 * date format suggest a MySQL/MariaDB backend — confirm.
 *
 * @param <T> An entity class
 * @author Lorenzo Ferron
 * @version 2019.09.28
 */
public class ResultSetMapper<T> {

    public static final String SET_CLAUSE = " SET ";
    public static final String WHERE_CLAUSE = " WHERE ";
    public static final String VALUES_CLAUSE = " VALUES ";
    public static final String FROM_CLAUSE = " FROM ";

    public static final String SELECT_SQL = "SELECT *" + FROM_CLAUSE;
    public static final String INSERT_SQL = "INSERT ";
    public static final String IGNORE_SQL = "IGNORE ";
    public static final String INTO_SQL = "INTO ";
    public static final String UPDATE_SQL = "UPDATE ";
    public static final String DELETE_SQL = "DELETE";
    public static final String DELIMITER = ", ";
    private static final String REPLACE_SQL = " ON DUPLICATE KEY UPDATE ";

    /**
     * Storage format for {@link YearMonth} values: the day part is always the literal {@code 00}.
     * DateTimeFormatter is immutable and thread-safe, so one shared instance suffices.
     */
    private static final DateTimeFormatter YEAR_MONTH_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-00");

    private final DataSource dataSource;
    private final Class<T> clazz;
    private Table table;
    /** Mapped fields, in the order returned by ClassUtil. */
    private Field[] fields;
    /** Subset of {@link #fields} flagged as primary key, in the same order. */
    private final List<Field> primaryKeyField = new ArrayList<>(0);

    /**
     * Creates a mapper for {@code clazz}.
     *
     * @param dataSource source of the connections used by every operation
     * @param clazz      entity class. When {@code null}, no metadata is read, and operations that need
     *                   the table name or the mapped fields will fail with NullPointerException.
     * @throws AnnotationNotFoundException if {@code clazz} is not annotated with {@link Table}
     */
    public ResultSetMapper(DataSource dataSource, Class<T> clazz) {
        this.dataSource = dataSource;
        this.clazz = clazz;

        if (clazz != null) {
            if (!clazz.isAnnotationPresent(Table.class))
                throw new AnnotationNotFoundException();
            table = clazz.getAnnotation(Table.class);
            fields = ClassUtil.getAnnotatedDeclaredFields(clazz, Column.class);
            for (Field field : fields)
                if (field.getAnnotation(Column.class).isPrimaryKey())
                    primaryKeyField.add(field);
        }
    }

    /**
     * Checks that the number of {@code ?} characters in {@code criteria} equals the number of params.
     * The count is purely textual, so a {@code ?} inside a SQL string literal is counted too.
     *
     * @throws DAOException if the counts differ
     */
    private static void questionMarksParamsCount(String criteria, Object[] params) {
        int criteriaCount = criteria == null ? 0 : CharMatcher.is('?').countIn(criteria);
        int paramsCount = params != null ? params.length : 0;
        if (criteriaCount != paramsCount)
            throw new DAOException("?: criteria = " + criteriaCount + "; params = " + paramsCount);
    }

    /**
     * Selects every row of the entity table that matches {@code criteria} and maps each row to a
     * new entity instance.
     *
     * @param criteria SQL condition appended after {@code WHERE}, or {@code null} to select all rows
     * @param params   values bound, in order, to the {@code ?} placeholders of {@code criteria}
     * @param <E>      element type of the returned list. The mapped entities are cast to it
     *                 unchecked, so it should be the mapper's entity type.
     * @return the mapped entities, possibly empty
     * @throws DAOException           if the number of {@code ?} does not match the number of params
     * @throws IllegalStateException  if the entity cannot be instantiated or a field cannot be written
     * @throws ClassCastException     if a primitive field receives a value of another boxed type
     * @throws IllegalArgumentException if a column value cannot be assigned to its field
     *                                (e.g. SQL NULL into a primitive field)
     * @throws SQLException           on database errors
     */
    @SuppressWarnings("unchecked")
    public <E> List<E> findByCriteria(String criteria, Object... params) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        questionMarksParamsCount(criteria, params);

        String sqlStatement = SELECT_SQL + table.name();
        if (criteria != null)
            sqlStatement += WHERE_CLAUSE + criteria.trim();

        List<E> result = new ArrayList<>();
        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sqlStatement)) {
            bindParams(stmt, 1, params);
            logStatement(stmt);
            try (ResultSet rs = stmt.executeQuery()) {
                while (rs.next())
                    result.add((E) mapRow(rs));
            }
        }
        return result;
    }

    /**
     * Returns the criteria fragment {@code "<pk column> = ?"} for the single primary key.
     * Despite its name, this method does not run a query.
     *
     * @throws NullPointerException          if the entity has no primary key
     * @throws UnsupportedOperationException if the entity has a composite primary key
     */
    public String findById() {
        return singlePrimaryKeyColumn() + " = ?";
    }

    /**
     * Deletes the row whose single primary-key column equals {@code id}.
     *
     * @throws NullPointerException          if the entity has no primary key
     * @throws UnsupportedOperationException if the entity has a composite primary key
     * @throws SQLException                  on database errors
     */
    public void deleteById(Long id) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        deleteByCriteria(singlePrimaryKeyColumn() + " = ?", id);
    }

    /**
     * Only validates that {@code criteria} and {@code params} agree on the number of placeholders.
     * No query is executed.
     * NOTE(review): looks unfinished; findByCriteria is presumably the method to use — confirm.
     *
     * @throws DAOException if the number of {@code ?} does not match the number of params
     */
    public void getByCriteria(String criteria, Object... params) {
        questionMarksParamsCount(criteria, params);
    }

    /**
     * Updates the row identified by the primary-key values currently held by {@code item}.
     * Zero matched rows (e.g. a null key value) is not reported as an error.
     *
     * @throws NullPointerException if the entity has no primary key
     * @throws SQLException         on database errors
     */
    public void update(T item) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        requirePrimaryKey();
        StringJoiner joiner = new StringJoiner(" AND ");
        Object[] params = criteriaJoiner(item, joiner);
        updateByCriteria(item, joiner.toString(), params);
    }

    /**
     * Writes the current values of {@code item} to every row matching {@code criteria}.
     * <p>
     * The SET clause contains the fields declared by the entity class itself plus every
     * primary-key field, excluding columns whose {@code hold()} is true.
     *
     * @param criteria SQL condition appended after {@code WHERE}. It must not be null; an
     *                 unconditional UPDATE is refused.
     * @param params   values bound, in order, to the {@code ?} placeholders of {@code criteria}.
     *                 They are bound after the SET values.
     * @throws DAOException         if the number of {@code ?} does not match the number of params
     * @throws NullPointerException if {@code criteria} is null
     * @throws SQLException         on database errors
     */
    public void updateByCriteria(T item, String criteria, Object... params) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        questionMarksParamsCount(criteria, params);
        if (criteria == null)
            throw new NullPointerException("criteria must not be null: refusing to UPDATE every row");

        StringJoiner assignments = new StringJoiner(DELIMITER);
        List<Field> filteredFields = new ArrayList<>();
        for (Field field : fields) {
            Column column = field.getAnnotation(Column.class);
            if ((field.getDeclaringClass().equals(clazz) || column.isPrimaryKey()) && !column.hold()) {
                assignments.add(column.name() + " = ?");
                filteredFields.add(field);
            }
        }

        String sqlStatement = UPDATE_SQL + table.name() + SET_CLAUSE + assignments +
                WHERE_CLAUSE + criteria.trim();

        executeStatement(sqlStatement, item, filteredFields, false, params);
    }

    /**
     * Executes {@code sqlStatement} as an update.
     * <p>
     * The values of {@code filteredFields} (read from {@code item}) are bound first, then
     * {@code params}. When {@code generateKeys} is true, the entity has a primary key, and every
     * primary-key field of {@code item} is null, the database-generated key is read back into
     * the first primary-key field.
     *
     * @throws SQLException if the statement fails, or if a key was expected but no row was affected
     *                      or no key was returned
     */
    private void executeStatement(String sqlStatement, T item, List<Field> filteredFields,
                                  boolean generateKeys, Object... params) throws SQLException {
        // allMatch is true on an empty list, so entities without a primary key are excluded explicitly.
        boolean keysGeneration = generateKeys && !primaryKeyField.isEmpty()
                && primaryKeyField.stream().allMatch(field -> readField(field, item) == null);

        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = keysGeneration
                     ? conn.prepareStatement(sqlStatement, Statement.RETURN_GENERATED_KEYS)
                     : conn.prepareStatement(sqlStatement)) {
            int counter = 1;
            for (Field field : filteredFields) {
                Object value = readField(field, item);
                if (field.getType().equals(YearMonth.class)) {
                    // A null YearMonth is bound as SQL NULL instead of failing inside format().
                    stmt.setString(counter++, value == null ? null : ((YearMonth) value).format(YEAR_MONTH_FORMAT));
                } else {
                    // Empty strings are stored as NULL in nullable columns.
                    boolean nullable = field.getAnnotation(Column.class).isNullable();
                    stmt.setObject(counter++, nullable && "".equals(value) ? null : value);
                }
            }
            bindParams(stmt, counter, params);

            logStatement(stmt);

            int affectedRows = stmt.executeUpdate();
            if (keysGeneration) {
                if (affectedRows == 0)
                    throw new SQLException("No rows affected.");
                try (ResultSet generatedKeys = stmt.getGeneratedKeys()) {
                    if (!generatedKeys.next())
                        throw new SQLException("No ID obtained.");
                    Field keyField = primaryKeyField.get(0);
                    writeField(keyField, item, generatedKeyValue(keyField.getType(), generatedKeys.getLong(1)));
                }
            }
        }
    }

    /**
     * Inserts {@code item}.
     * <p>
     * The columns written are those declared by the entity class itself plus every primary-key
     * field. A null value is omitted when its column is a primary key or has a default value, so
     * the database can supply it. When the inserted entity's primary key was null, the generated
     * key is stored back into {@code item}.
     *
     * @param replaceMode when true, appends {@code ON DUPLICATE KEY UPDATE} so that an existing row
     *                    with the same key is overwritten with the inserted values
     * @param ignore      when true, issues {@code INSERT IGNORE}
     * @throws SQLException on database errors
     */
    public void save(T item, boolean replaceMode, boolean ignore) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        StringJoiner columns = new StringJoiner(DELIMITER, " ( ", " ) ");
        StringJoiner placeholders = new StringJoiner(DELIMITER, " (", ")");
        List<Field> filteredFields = new ArrayList<>();

        for (Field field : fields) {
            Column column = field.getAnnotation(Column.class);
            if (!(field.getDeclaringClass().equals(clazz) || column.isPrimaryKey()))
                continue;
            Object value = readField(field, item);
            // The former "(replaceMode || ignore) && isPrimaryKey && value != null" alternative
            // was already implied by "value != null", so the two flags never affected this filter.
            boolean insertable = value != null || !(column.isPrimaryKey() || column.hasDefaultValue());
            if (insertable) {
                columns.add(column.name());
                placeholders.add("?");
                filteredFields.add(field);
            }
        }

        StringBuilder sql = new StringBuilder(INSERT_SQL);
        if (ignore)
            sql.append(IGNORE_SQL);
        sql.append(INTO_SQL).append(table.name()).append(columns).append(VALUES_CLAUSE).append(placeholders);

        if (replaceMode) {
            StringJoiner updates = new StringJoiner(DELIMITER);
            for (Field field : filteredFields) {
                String nameColumn = field.getAnnotation(Column.class).name();
                updates.add(nameColumn + "=" + VALUES_CLAUSE + "(" + nameColumn + ")");
            }
            sql.append(REPLACE_SQL).append(updates);
        }

        executeStatement(sql.toString(), item, filteredFields, true);
    }

    /**
     * Adds {@code "<pk column> = ?"} to {@code joiner} for every primary-key field.
     *
     * @return the matching primary-key values read from {@code item}, in the same order
     */
    private Object[] criteriaJoiner(T item, StringJoiner joiner) {
        List<Object> params = new ArrayList<>(primaryKeyField.size());
        for (Field field : primaryKeyField) {
            joiner.add(field.getAnnotation(Column.class).name() + " = ?");
            params.add(readField(field, item));
        }
        return params.toArray();
    }

    /**
     * Deletes the row identified by the primary-key values currently held by {@code item}.
     *
     * @throws NullPointerException if the entity has no primary key
     * @throws SQLException         on database errors
     */
    public void delete(T item) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        requirePrimaryKey();
        StringJoiner joiner = new StringJoiner(" AND ");
        Object[] params = criteriaJoiner(item, joiner);
        deleteByCriteria(joiner.toString(), params);
    }

    /**
     * Deletes every row matching {@code criteria}.
     *
     * @param criteria SQL condition appended after {@code WHERE}. It must not be null; an
     *                 unconditional DELETE is refused.
     * @param params   values bound, in order, to the {@code ?} placeholders of {@code criteria}
     * @throws DAOException         if the number of {@code ?} does not match the number of params
     * @throws NullPointerException if {@code criteria} is null
     * @throws SQLException         on database errors
     */
    public void deleteByCriteria(String criteria, Object... params) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        questionMarksParamsCount(criteria, params);
        if (criteria == null)
            throw new NullPointerException("criteria must not be null: refusing to DELETE every row");

        String sqlStatement = DELETE_SQL + FROM_CLAUSE + table.name() + WHERE_CLAUSE + criteria.trim();

        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(sqlStatement)) {
            bindParams(stmt, 1, params);
            logStatement(stmt);
            stmt.executeUpdate();
        }
    }

    /**
     * Runs an arbitrary SQL statement.
     *
     * @return when the statement produces a result set, a {@code List<T>} with every row mapped to
     *         the entity (each mapped column must be present). Otherwise, the first generated key as
     *         a {@code Long}, or {@code null} when none was returned.
     * @throws DAOException  if the number of {@code ?} does not match the number of params
     * @throws SQLException  on database errors
     */
    public Object customQuery(String query, Object... params) throws SQLException, IllegalArgumentException, NullPointerException, NumberParseException {
        questionMarksParamsCount(query, params);

        try (Connection conn = dataSource.getConnection();
             PreparedStatement stmt = conn.prepareStatement(query, Statement.RETURN_GENERATED_KEYS)) {
            bindParams(stmt, 1, params);
            logStatement(stmt);
            if (stmt.execute()) {
                try (ResultSet rs = stmt.getResultSet()) {
                    List<T> items = new ArrayList<>();
                    while (rs.next())
                        items.add(mapRow(rs));
                    return items;
                }
            }
            try (ResultSet rs = stmt.getGeneratedKeys()) {
                return rs.next() ? rs.getLong(1) : null;
            }
        }
    }

    // ---------------------------------------------------------------- private helpers

    /**
     * Builds one entity from the current row of {@code rs}.
     * Each mapped field is read from the column named by its {@link Column} annotation.
     */
    private T mapRow(ResultSet rs) throws SQLException {
        T item = newEntity();
        for (Field field : fields) {
            String columnName = field.getAnnotation(Column.class).name();
            Class<?> type = field.getType();
            Object value;
            if (type.equals(YearMonth.class)) {
                String raw = rs.getString(columnName);
                // A SQL NULL becomes a null YearMonth instead of an NPE inside YearMonth.parse.
                value = raw == null ? null : YearMonth.parse(raw, YEAR_MONTH_FORMAT);
            } else {
                value = rs.getObject(columnName);
            }
            if (type.isPrimitive())
                // Fails fast with ClassCastException when the driver returns a different boxed type.
                value = Primitives.wrap(type).cast(value);
            writeField(field, item, value);
        }
        return item;
    }

    /** Instantiates the entity through its no-argument constructor. */
    private T newEntity() {
        try {
            return clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate " + clazz.getName() + " through its no-argument constructor", e);
        }
    }

    /** Reads {@code field} from {@code target}, making it accessible first. */
    private static Object readField(Field field, Object target) {
        try {
            field.setAccessible(true);
            return field.get(target);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot read field " + field.getName(), e);
        }
    }

    /** Writes {@code value} into {@code field} of {@code target}, making it accessible first. */
    private static void writeField(Field field, Object target, Object value) {
        try {
            field.setAccessible(true);
            field.set(target, value);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot write field " + field.getName(), e);
        }
    }

    /**
     * Converts a generated key to the key field's type.
     * Field.set does not narrow a Long into an Integer field, so Integer keys are converted here;
     * overflow raises ArithmeticException.
     */
    private static Object generatedKeyValue(Class<?> keyType, long key) {
        if (keyType == Integer.class || keyType == int.class)
            return Math.toIntExact(key);
        return key;
    }

    /**
     * Binds {@code params} (which may be null) to consecutive placeholders starting at
     * {@code firstIndex}.
     *
     * @return the next unused placeholder index
     */
    private static int bindParams(PreparedStatement stmt, int firstIndex, Object... params) throws SQLException {
        int index = firstIndex;
        if (params != null)
            for (Object param : params)
                stmt.setObject(index++, param);
        return index;
    }

    /** Echoes the bound statement to standard output. */
    private static void logStatement(PreparedStatement stmt) {
        System.out.println(stmt.toString()); // For debug purpose
    }

    /** @throws NullPointerException if the entity declares no primary-key column */
    private void requirePrimaryKey() {
        if (primaryKeyField.isEmpty())
            throw new NullPointerException("Primary Key is not found");
    }

    /**
     * Returns the column name of the single primary key.
     *
     * @throws NullPointerException          if there is no primary key
     * @throws UnsupportedOperationException if the primary key is composite
     */
    private String singlePrimaryKeyColumn() {
        requirePrimaryKey();
        if (primaryKeyField.size() > 1)
            throw new UnsupportedOperationException("Only one primary key");
        return primaryKeyField.get(0).getAnnotation(Column.class).name();
    }
}