Java程序  |  151行  |  5.87 KB

package junitparams.internal.parameters;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.junit.runners.model.FrameworkMethod;

import junitparams.Parameters;

/**
 * Obtains parameter sets for a test method by reflectively invoking
 * parameter-providing methods.
 *
 * <p>Provider methods are named either explicitly through
 * {@link Parameters#method()} (a comma-separated list of names) or, when that
 * value is empty, by the convention {@code parametersFor<TestMethodName>}.
 * Providers are looked up by name, without arguments, in the source class and
 * its superclasses.
 */
class ParamsFromMethodCommon {
    private final FrameworkMethod frameworkMethod;

    /**
     * @param frameworkMethod the test method whose parameters are to be resolved
     */
    ParamsFromMethodCommon(FrameworkMethod frameworkMethod) {
        this.frameworkMethod = frameworkMethod;
    }

    /**
     * Collects the parameter sets for the test method.
     *
     * <p>If the {@link Parameters} annotation names no method, the default
     * provider ({@code parametersFor<TestMethodName>}) is invoked. Otherwise
     * every comma-separated name is trimmed and invoked in order, and the
     * results are concatenated.
     *
     * @param sourceClass class that is instantiated and searched (together with
     *                    its superclasses) for the provider methods
     * @return the parameter sets produced by the provider method(s)
     * @throws RuntimeException if a provider cannot be found, cannot be invoked,
     *                          or returns an unsupported type
     */
    Object[] paramsFromMethod(Class<?> sourceClass) {
        String methodAnnotation = frameworkMethod.getAnnotation(Parameters.class).method();

        if (methodAnnotation.isEmpty()) {
            return invokeMethodWithParams(defaultMethodName(), sourceClass);
        }

        List<Object> result = new ArrayList<Object>();
        for (String methodName : methodAnnotation.split(",")) {
            for (Object param : invokeMethodWithParams(methodName.trim(), sourceClass)) {
                result.add(param);
            }
        }

        return result.toArray();
    }

    /**
     * Invokes the given provider without a receiver (i.e. as a static method)
     * and normalises its result.
     *
     * @param providerMethod provider to invoke; its result is cast to {@code Object[]}
     * @return the provider's result, wrapped in an outer array when it looks
     *         like a single parameter set
     * @throws IllegalAccessException    if the method is not accessible
     * @throws InvocationTargetException if the provider itself throws
     */
    Object[] getDataFromMethod(Method providerMethod) throws IllegalAccessException, InvocationTargetException {
        return encapsulateParamsIntoArrayIfSingleParamsetPassed((Object[]) providerMethod.invoke(null));
    }

    /**
     * @param sourceClass class whose hierarchy is searched
     * @return {@code true} if a no-argument method following the default naming
     *         convention is declared in {@code sourceClass} or a superclass
     */
    boolean containsDefaultParametersProvidingMethod(Class<?> sourceClass) {
        return findMethodInTestClassHierarchy(defaultMethodName(), sourceClass) != null;
    }

    /**
     * Builds the conventional provider name: {@code "parametersFor"} followed by
     * the test method name with its first character upper-cased.
     */
    private String defaultMethodName() {
        String testName = frameworkMethod.getName();
        // Locale.ROOT keeps the derived name independent of the default locale
        // (e.g. under a Turkish locale 'i' would otherwise become 'İ').
        return "parametersFor" + testName.substring(0, 1).toUpperCase(Locale.ROOT)
                + testName.substring(1);
    }

    /**
     * Looks up the named provider in the class hierarchy and invokes it.
     *
     * @throws RuntimeException if no such method exists
     */
    private Object[] invokeMethodWithParams(String methodName, Class<?> sourceClass) {
        Method providerMethod = findMethodInTestClassHierarchy(methodName, sourceClass);
        if (providerMethod == null) {
            throw new RuntimeException("Could not find method: " + methodName + " so no params were used.");
        }

        return invokeParamsProvidingMethod(providerMethod, sourceClass);
    }

    /**
     * Instantiates {@code sourceClass}, invokes the provider on that instance
     * and converts the result into an array of parameter sets.
     *
     * <p>Supported results are {@code Object[]}, {@code Iterable} and
     * {@code Iterator}; an object that is both an {@code Iterable} and an
     * {@code Iterator} is treated as an {@code Iterable}.
     *
     * @throws RuntimeException if the result has an unsupported type, or if
     *                          instantiation or invocation fails for any reason
     */
    private Object[] invokeParamsProvidingMethod(Method provideMethod, Class<?> sourceClass) {
        try {
            Object testObject = sourceClass.newInstance();
            provideMethod.setAccessible(true);
            Object result = provideMethod.invoke(testObject);

            if (result == null) {
                // Surfaces through the generic handler below with an explicit cause
                // instead of an incidental NullPointerException.
                throw new IllegalStateException("Parameters providing method returned null");
            }

            if (result instanceof Object[]) {
                return encapsulateParamsIntoArrayIfSingleParamsetPassed((Object[]) result);
            }

            if (result instanceof Iterable) {
                return paramsetsFrom(((Iterable<?>) result).iterator());
            }

            if (result instanceof Iterator) {
                return paramsetsFrom((Iterator<?>) result);
            }

            // Unsupported return type: reported by the handler below.
            throw new ClassCastException();

        } catch (ClassCastException e) {
            throw new RuntimeException("The return type of: " + provideMethod.getName() + " defined in class " +
                    sourceClass + " is not Object[][] nor Iterable<Object[]>. Fix it!", e);
        } catch (Exception e) {
            throw new RuntimeException("Could not invoke method: " + provideMethod.getName() + " defined in class " +
                    sourceClass + " so no params were used.", e);
        }
    }

    /**
     * Drains the iterator in a single pass and turns its elements into
     * parameter sets.
     *
     * <p>If every element is an {@code Object[]} (or {@code null}) the elements
     * are returned unchanged, each being one parameter set. Otherwise the
     * source is taken to yield one parameter per test invocation and every
     * element is wrapped in its own single-element {@code Object[]}, so all
     * entries have the same shape and none are lost.
     */
    private Object[] paramsetsFrom(Iterator<?> elements) {
        List<Object> collected = new ArrayList<Object>();
        boolean allParamsets = true;
        while (elements.hasNext()) {
            Object element = elements.next();
            // null counts as a paramset, as it would survive a cast to Object[]
            if (element != null && !(element instanceof Object[])) {
                allParamsets = false;
            }
            collected.add(element);
        }

        if (allParamsets) {
            return collected.toArray();
        }

        Object[] wrapped = new Object[collected.size()];
        for (int i = 0; i < wrapped.length; i++) {
            wrapped[i] = new Object[]{collected.get(i)};
        }
        return wrapped;
    }

    /**
     * Searches {@code sourceClass} and then its superclasses for a declared
     * no-argument method with the given name. Classes without a superclass
     * (such as {@code Object}) are not searched.
     *
     * @return the first matching method, or {@code null} if there is none
     */
    private Method findMethodInTestClassHierarchy(String methodName, Class<?> sourceClass) {
        for (Class<?> declaringClass = sourceClass;
             declaringClass.getSuperclass() != null;
             declaringClass = declaringClass.getSuperclass()) {
            try {
                return declaringClass.getDeclaredMethod(methodName);
            } catch (NoSuchMethodException ignored) {
                // not declared at this level; continue with the superclass
            } catch (SecurityException ignored) {
                // lookup denied at this level; best effort, continue with the superclass
            }
        }
        return null;
    }

    /**
     * Wraps {@code params} in an outer array when it appears to be a single
     * parameter set rather than an array of parameter sets.
     *
     * <p>That is the case when its length equals the number of parameters of
     * the test method, it is not empty, and its first element is {@code null}
     * or not an array. In every other case {@code params} is returned as is.
     */
    private Object[] encapsulateParamsIntoArrayIfSingleParamsetPassed(Object[] params) {
        if (frameworkMethod.getMethod().getParameterTypes().length != params.length) {
            return params;
        }

        if (params.length == 0) {
            return params;
        }

        Object param = params[0];
        if (param == null || !param.getClass().isArray()) {
            return new Object[]{params};
        }

        return params;
    }

}