// Java source  |  527 lines  |  19.94 KB

package org.mockitoutil;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.objenesis.Objenesis;
import org.objenesis.ObjenesisStd;
import org.objenesis.instantiator.ObjectInstantiator;

import static java.lang.String.format;
import static java.util.Arrays.asList;

public abstract class ClassLoaders {
    /**
     * Parent class loader available to subclasses; initialised to the loader that loaded
     * {@link ClassLoaders}. In this file it is read by {@code InMemoryClassLoaderBuilder#build()}
     * and replaced by {@code InMemoryClassLoaderBuilder#withParent(ClassLoader)}.
     */
    protected ClassLoader parent = currentClassLoader();

    /** Protected, so instances can only be created through subclasses (the builders). */
    protected ClassLoaders() {
    }

    /**
     * Entry point for building a class loader that only defines classes matching the configured
     * "private copy" prefixes and none of the excluded prefixes.
     *
     * @return a new, unconfigured {@link IsolatedURLClassLoaderBuilder}
     */
    public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
        return new IsolatedURLClassLoaderBuilder();
    }

    /**
     * Entry point for building a class loader that refuses to define classes matching the
     * configured excluded prefixes.
     *
     * @return a new, unconfigured {@link ExcludingURLClassLoaderBuilder}
     */
    public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
        return new ExcludingURLClassLoaderBuilder();
    }

    /**
     * Entry point for building a class loader that defines classes from byte arrays held in memory.
     *
     * @return a new, unconfigured {@link InMemoryClassLoaderBuilder}
     */
    public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
        return new InMemoryClassLoaderBuilder();
    }

    /**
     * Creates a finder that lists the classes reachable through the given class loader.
     *
     * @param classLoader the class loader whose resources will be inspected
     * @return a new {@link ReachableClassesFinder} bound to {@code classLoader}
     */
    public static ReachableClassesFinder in(ClassLoader classLoader) {
        return new ReachableClassesFinder(classLoader);
    }

    /**
     * Returns the class loader that loaded {@link String}.
     *
     * NOTE(review): {@code Class#getClassLoader()} may return {@code null} to represent the
     * bootstrap loader, so callers should be prepared for a {@code null} result.
     *
     * @return the loader of {@code java.lang.String}, possibly {@code null}
     */
    public static ClassLoader jdkClassLoader() {
        return String.class.getClassLoader();
    }

    /**
     * Returns the JVM system class loader.
     *
     * @return the result of {@link ClassLoader#getSystemClassLoader()}
     */
    public static ClassLoader systemClassLoader() {
        return ClassLoader.getSystemClassLoader();
    }

    /**
     * Returns the class loader that loaded {@link ClassLoaders} itself.
     *
     * @return the loader of this utility class
     */
    public static ClassLoader currentClassLoader() {
        return ClassLoaders.class.getClassLoader();
    }

    /**
     * Creates the class loader described by the builder's current configuration.
     *
     * @return the class loader produced by the concrete builder
     */
    public abstract ClassLoader build();

    /**
     * Looks up a fixed list of coverage/logging classes by name and returns those that could be
     * loaded.
     *
     * @return the classes that were found; an empty array when none of them is available
     */
    public static Class<?>[] coverageTool() {
        String[] candidates = {
            "net.sourceforge.cobertura.coveragedata.TouchCollector",
            "org.slf4j.LoggerFactory"
        };

        Set<Class<?>> found = new HashSet<Class<?>>();
        for (String candidate : candidates) {
            Class<?> resolved = safeGetClass(candidate);
            // Unavailable classes are reported as null and simply left out.
            if (resolved != null) {
                found.add(resolved);
            }
        }
        return found.toArray(new Class<?>[found.size()]);
    }

    /**
     * Loads a class by name, translating "not found" into {@code null}.
     *
     * @param className fully qualified name of the class to load
     * @return the loaded class, or {@code null} when a {@link ClassNotFoundException} is raised
     */
    private static Class<?> safeGetClass(String className) {
        Class<?> resolved;
        try {
            resolved = Class.forName(className);
        } catch (ClassNotFoundException notAvailable) {
            resolved = null;
        }
        return resolved;
    }

    /**
     * Creates an executor that runs tasks with the given class loader.
     *
     * @param classLoader the class loader the executed tasks will be reloaded in
     * @return a new {@link ClassLoaderExecutor} bound to {@code classLoader}
     */
    public static ClassLoaderExecutor using(final ClassLoader classLoader) {
        return new ClassLoaderExecutor(classLoader);
    }

    /**
     * Runs a {@link Runnable} on a dedicated thread whose context class loader is the configured
     * loader. Before running, the task is re-created from the class of the same name as loaded by
     * that loader (see {@link #reloadTaskInClassLoader(Runnable)}).
     */
    public static class ClassLoaderExecutor {
        private final ClassLoader classLoader;

        /**
         * @param classLoader the loader used both as thread context class loader and to reload
         *                    the task class
         */
        public ClassLoaderExecutor(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /**
         * Reloads the task in the configured class loader, runs it on a single-use worker thread
         * and waits for completion.
         *
         * <p>The worker executor is always shut down before this method returns, whether the task
         * succeeded, failed, or the calling thread was interrupted while waiting.
         *
         * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored
         * and the method returns normally.
         *
         * @param task the task to reload and run
         * @throws Exception the cause of the {@link ExecutionException} raised by the worker; any
         *                   {@link Throwable} raised while reloading or running the task is first
         *                   wrapped in an {@link IllegalStateException}
         */
        public void execute(final Runnable task) throws Exception {
            ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() {
                @Override
                public Thread newThread(Runnable r) {
                    Thread thread = Executors.defaultThreadFactory().newThread(r);
                    thread.setContextClassLoader(classLoader);
                    return thread;
                }
            });
            try {
                Future<?> taskFuture = executorService.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            reloadTaskInClassLoader(task).run();
                        } catch (Throwable throwable) {
                            throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s",
                                                                   task,
                                                                   throwable.getMessage()),
                                                            throwable);
                        }
                    }
                });
                taskFuture.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (ExecutionException e) {
                throw this.<Exception>unwrapAndThrows(e);
            } finally {
                // Shut down on every path: the worker thread is not a daemon, so leaving the
                // executor running after a failure would leak the thread.
                executorService.shutdownNow();
            }
        }

        /**
         * Rethrows the cause of the given exception as-is. The unchecked cast lets the cause
         * propagate with its original type regardless of the declared type parameter.
         *
         * @param ex the execution failure to unwrap
         * @return never returns normally; declared only so callers can write {@code throw ...}
         * @throws T always, carrying {@code ex.getCause()}
         */
        @SuppressWarnings("unchecked")
        private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
            throw (T) ex.getCause();
        }

        /**
         * Creates a copy of the task whose class is loaded by the configured class loader.
         *
         * <p>The copy is instantiated through Objenesis and then populated with a shallow copy of
         * the fields declared directly on the task's class.
         *
         * @param task the task to re-create
         * @return an instance of the reloaded task class carrying the copied field values
         * @throws IllegalStateException if the class cannot be loaded by the class loader, a
         *                               declared field is missing on the reloaded class, or a
         *                               field cannot be accessed
         */
        Runnable reloadTaskInClassLoader(Runnable task) {
            try {
                @SuppressWarnings("unchecked")
                Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName());

                Objenesis objenesis = new ObjenesisStd();
                ObjectInstantiator<Runnable> thingyInstantiator = objenesis.getInstantiatorOf(taskClassReloaded);
                Runnable reloaded = thingyInstantiator.newInstance();

                // lenient shallow copy of class compatible fields
                for (Field field : task.getClass().getDeclaredFields()) {
                    Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
                    int modifiers = declaredField.getModifiers();
                    if (Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
                        // Skip static final fields (e.g. jacoco fields)
                        // otherwise IllegalAccessException (can be bypassed with Unsafe though)
                        // We may also miss coverage data.
                        continue;
                    }
                    // Only copy when both sides see the very same field type; a field whose type
                    // was itself reloaded has a different Class object and is left unset
                    // (the original comment hints at the enclosing "this" reference).
                    if (declaredField.getType() == field.getType()) {
                        field.setAccessible(true);
                        declaredField.setAccessible(true);
                        declaredField.set(reloaded, field.get(task));
                    }
                }

                return reloaded;
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException(e);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            } catch (NoSuchFieldException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
            privateCopyPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) {
            excludedPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        public ClassLoader build() {
            return new LocalIsolatedURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    privateCopyPrefixes,
                    excludedPrefixes
            );
        }
    }

    /**
     * URL class loader that only defines classes whose name starts with one of the "private copy"
     * prefixes and with none of the excluded prefixes.
     */
    static class LocalIsolatedURLClassLoader extends URLClassLoader {
        private final ArrayList<String> privateCopyPrefixes;
        private final ArrayList<String> excludedPrefixes;

        /**
         * @param classLoader         parent class loader
         * @param urls                locations to read class files from
         * @param privateCopyPrefixes class-name prefixes this loader may define
         * @param excludedPrefixes    class-name prefixes this loader must refuse
         */
        LocalIsolatedURLClassLoader(ClassLoader classLoader,
                                    URL[] urls,
                                    ArrayList<String> privateCopyPrefixes,
                                    ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            this.privateCopyPrefixes = privateCopyPrefixes;
            this.excludedPrefixes = excludedPrefixes;
        }

        /**
         * Defines the class from the configured URLs when its name is allowed.
         *
         * @param name binary name of the class
         * @return the defined class
         * @throws ClassNotFoundException if the name is not allowed by the prefix rules, or if the
         *                                class file is not found in the configured URLs
         */
        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            boolean allowed = startsWithAny(name, privateCopyPrefixes)
                    && !startsWithAny(name, excludedPrefixes);
            if (!allowed) {
                String reason = format("Can only load classes with prefixes : %s, but not : %s",
                                       privateCopyPrefixes,
                                       excludedPrefixes);
                throw new ClassNotFoundException(reason);
            }

            try {
                return super.findClass(name);
            } catch (ClassNotFoundException cnfe) {
                String hint = "    Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?";
                String message = format("%s%n%s%n", cnfe.getMessage(), hint);
                throw new ClassNotFoundException(message, cnfe);
            }
        }

        /** Tells whether {@code name} starts with at least one of the given prefixes. */
        private static boolean startsWithAny(String name, Iterable<String> prefixes) {
            boolean matched = false;
            for (String prefix : prefixes) {
                if (name.startsWith(prefix)) {
                    matched = true;
                    break;
                }
            }
            return matched;
        }
    }

    public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) {
            excludedPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        public ClassLoader build() {
            return new LocalExcludingURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    excludedPrefixes
            );
        }
    }

    /**
     * URL class loader that refuses to define any class whose name starts with one of the
     * excluded prefixes.
     */
    static class LocalExcludingURLClassLoader extends URLClassLoader {
        private final ArrayList<String> excludedPrefixes;

        /**
         * @param classLoader      parent class loader
         * @param urls             locations to read class files from
         * @param excludedPrefixes class-name prefixes this loader must refuse
         */
        LocalExcludingURLClassLoader(ClassLoader classLoader,
                                     URL[] urls,
                                     ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            this.excludedPrefixes = excludedPrefixes;
        }

        /**
         * Defines the class from the configured URLs unless its name is excluded.
         *
         * @param name binary name of the class
         * @return the defined class
         * @throws ClassNotFoundException if the name matches an excluded prefix, or if the class
         *                                file is not found in the configured URLs
         */
        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            if (!classShouldBeExcluded(name)) {
                return super.findClass(name);
            }
            String reason = "classes with prefix : " + excludedPrefixes + " are excluded";
            throw new ClassNotFoundException(reason);
        }

        /** Tells whether {@code name} starts with at least one excluded prefix. */
        private boolean classShouldBeExcluded(String name) {
            boolean excluded = false;
            for (int i = 0; i < excludedPrefixes.size() && !excluded; i++) {
                excluded = name.startsWith(excludedPrefixes.get(i));
            }
            return excluded;
        }
    }

    /**
     * Fluent builder for a class loader that defines classes from byte arrays registered by name.
     */
    public static class InMemoryClassLoaderBuilder extends ClassLoaders {
        // Class definitions keyed by the name they were registered under.
        private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();

        /**
         * Replaces the parent class loader used by {@link #build()}.
         *
         * @param parent the parent class loader for the built loader
         * @return this builder
         */
        public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
            this.parent = parent;
            return this;
        }

        /**
         * Registers a class definition, replacing any earlier one stored under the same name.
         *
         * @param name            the name the definition is looked up by
         * @param classDefinition the bytes handed to {@code defineClass} when the class is loaded
         * @return this builder
         */
        public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
            inMemoryClassObjects.put(name, classDefinition);
            return this;
        }

        /**
         * Builds the in-memory class loader. The loader receives this builder's map itself, not a
         * copy, so definitions registered afterwards are visible to it as well.
         *
         * @return a new {@code InMemoryClassLoader} using the configured parent
         */
        public ClassLoader build() {
            return new InMemoryClassLoader(parent, inMemoryClassObjects);
        }
    }

    /**
     * Class loader that defines classes from byte arrays held in a map, and exposes every stored
     * entry as a {@code mem:} URL through {@link #getResources(String)}.
     */
    static class InMemoryClassLoader extends ClassLoader {
        /** URL scheme used for the resources reported by this loader. */
        public static final String SCHEME = "mem";

        // Class definitions keyed by class name; also read by MemHandler's connection.
        private final Map<String, byte[]> inMemoryClassObjects;

        /**
         * @param parent               parent class loader
         * @param inMemoryClassObjects class definitions keyed by name; the map is used as-is
         */
        public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
            super(parent);
            this.inMemoryClassObjects = inMemoryClassObjects;
        }

        /**
         * Defines the class from the bytes stored under {@code name}.
         *
         * @throws ClassNotFoundException if no definition is stored under that name
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            byte[] bytecode = inMemoryClassObjects.get(name);
            if (bytecode == null) {
                throw new ClassNotFoundException(name);
            }
            return defineClass(name, bytecode, 0, bytecode.length);
        }

        /**
         * Lists one {@code mem:} URL per stored class definition; the requested resource name is
         * not taken into account.
         */
        @Override
        public Enumeration<URL> getResources(String ignored) throws IOException {
            return new StoredClassUrls();
        }

        /** Walks the stored class names and turns each one into a {@code mem:} URL. */
        private final class StoredClassUrls implements Enumeration<URL> {
            private final MemHandler handler = new MemHandler(InMemoryClassLoader.this);
            private final Iterator<String> classNames = inMemoryClassObjects.keySet().iterator();

            public boolean hasMoreElements() {
                return classNames.hasNext();
            }

            public URL nextElement() {
                String spec = SCHEME + ":" + classNames.next();
                try {
                    return new URL(null, spec, handler);
                } catch (MalformedURLException rethrown) {
                    throw new IllegalStateException(rethrown);
                }
            }
        }
    }

    /**
     * URL stream handler for the {@code mem:} URLs produced by {@code InMemoryClassLoader}; the
     * connections it opens serve the class bytes stored in that loader.
     */
    public static class MemHandler extends URLStreamHandler {
        private final InMemoryClassLoader inMemoryClassLoader;

        /**
         * @param inMemoryClassLoader the loader whose stored class definitions are served
         */
        public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
            this.inMemoryClassLoader = inMemoryClassLoader;
        }

        /**
         * Opens a connection backed by the in-memory class definitions.
         *
         * @param url the URL to connect to
         * @return a connection reading from the in-memory loader
         */
        @Override
        protected URLConnection openConnection(URL url) throws IOException {
            return new MemURLConnection(url, inMemoryClassLoader);
        }

        /** Connection that serves the stored bytes of the class named by the URL path. */
        private static class MemURLConnection extends URLConnection {
            private final InMemoryClassLoader inMemoryClassLoader;
            // Lookup key into the loader's map, taken from the URL path.
            private final String qualifiedName;

            public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
                super(url);
                this.inMemoryClassLoader = inMemoryClassLoader;
                qualifiedName = url.getPath();
            }

            /** Nothing to connect to: the data already lives in memory. */
            @Override
            public void connect() throws IOException {
            }

            /**
             * Returns a stream over the stored class definition.
             *
             * @return a stream over the bytes stored under the URL path
             * @throws IOException if no class definition is stored under that name
             */
            @Override
            public InputStream getInputStream() throws IOException {
                byte[] classDefinition = inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName);
                if (classDefinition == null) {
                    // Report a missing entry as an I/O failure instead of letting
                    // ByteArrayInputStream fail with a NullPointerException.
                    throw new IOException("No in-memory class definition for '" + qualifiedName + "'");
                }
                return new ByteArrayInputStream(classDefinition);
            }
        }
    }

    /**
     * Finds the class path root (directory or archive URL prefix) the given class is served from
     * by the loader of {@link ClassLoaders}.
     *
     * @param className fully qualified name of the class to locate
     * @return the URL of the class's resource with the class's own relative path stripped off
     * @throws IllegalArgumentException if the class file is not visible as a resource
     * @throws RuntimeException         if the stripped location is not a valid URL
     */
    URL obtainCurrentClassPathOf(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = ClassLoaders.class.getClassLoader().getResource(path);
        if (resource == null) {
            // getResource signals "not found" with null; fail with the offending name instead
            // of an anonymous NullPointerException.
            throw new IllegalArgumentException(
                    "Class '" + className + "' is not reachable as resource '" + path + "' from the current classloader");
        }
        String url = resource.toExternalForm();

        try {
            // Dropping the trailing "<package path>/<Class>.class" leaves the class path root.
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Converts file system paths to URLs.
     *
     * @param codeSourceUrls file system paths, absolute or relative
     * @return one URL per path, in the same order
     */
    List<URL> pathsToURLs(String... codeSourceUrls) {
        List<String> paths = Arrays.asList(codeSourceUrls);
        return pathsToURLs(paths);
    }

    /**
     * Converts each path of the list to a URL, preserving the order.
     *
     * @param codeSourceUrls file system paths, absolute or relative
     * @return a new list holding one URL per path
     */
    private List<URL> pathsToURLs(List<String> codeSourceUrls) {
        List<URL> converted = new ArrayList<URL>(codeSourceUrls.size());
        Iterator<String> remaining = codeSourceUrls.iterator();
        while (remaining.hasNext()) {
            converted.add(pathToUrl(remaining.next()));
        }
        return converted;
    }

    /**
     * Converts a single path to the URL of its absolute form.
     *
     * @param path file system path, absolute or relative
     * @return the URL of the absolute file
     * @throws IllegalArgumentException if the path cannot be expressed as a URL
     */
    private URL pathToUrl(String path) {
        File absolute = new File(path).getAbsoluteFile();
        try {
            return absolute.toURI().toURL();
        } catch (MalformedURLException e) {
            throw new IllegalArgumentException("Path is malformed", e);
        }
    }

    /**
     * Lists the qualified names of the classes reachable through a class loader, either from
     * directories on disk ({@code file:} roots) or from an in-memory loader ({@code mem:} URLs),
     * leaving out every name that contains one of the configured substrings.
     */
    public static class ReachableClassesFinder {
        private static final String CLASS_FILE_SUFFIX = ".class";

        private final ClassLoader classLoader;
        // Substrings that disqualify a class name from the result.
        private final Set<String> qualifiedNameSubstring = new HashSet<String>();

        /**
         * @param classLoader the class loader whose resources are inspected
         */
        ReachableClassesFinder(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /**
         * Adds substrings; any class whose qualified name contains one of them is left out.
         *
         * @param qualifiedNameSubstring substrings of qualified names to omit
         * @return this finder
         */
        public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
            this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
            return this;
        }

        /**
         * Collects the qualified names of all classes found under the roots reported by
         * {@code classLoader.getResources("")}.
         *
         * @return the qualified class names that are not omitted
         * @throws IOException              if the class loader fails to list its resources
         * @throws URISyntaxException       if a reported root cannot be converted to a URI
         * @throws IllegalArgumentException if a root uses a scheme other than {@code file} or
         *                                  the in-memory scheme
         */
        public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
            Enumeration<URL> roots = classLoader.getResources("");

            Set<String> classes = new HashSet<String>();
            while (roots.hasMoreElements()) {
                URI uri = roots.nextElement().toURI();

                if (uri.getScheme().equalsIgnoreCase("file")) {
                    addFromFileBasedClassLoader(classes, uri);
                } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) {
                    addFromInMemoryBasedClassLoader(classes, uri);
                } else {
                    throw new IllegalArgumentException(format("Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory", classLoader));
                }
            }
            return classes;
        }

        /** Adds every non-omitted class found in the directory tree rooted at {@code uri}. */
        private void addFromFileBasedClassLoader(Set<String> classes, URI uri) {
            File root = new File(uri);
            classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
        }

        /** Adds the class named by the scheme-specific part of {@code uri} unless it is omitted. */
        private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
            String qualifiedName = uri.getSchemeSpecificPart();
            if (isKept(qualifiedName, qualifiedNameSubstring)) {
                classes.add(qualifiedName);
            }
        }

        /**
         * Recursively collects the qualified names of the {@code .class} files under {@code file}.
         *
         * @param root           the class path root the names are computed relative to
         * @param file           the file or directory currently visited
         * @param packageFilters substrings that disqualify a class name
         * @return the qualified names found; empty when nothing matches
         */
        private Set<String> findClassQualifiedNames(File root, File file, Set<String> packageFilters) {
            if (file.isDirectory()) {
                Set<String> classes = new HashSet<String>();
                File[] children = file.listFiles();
                if (children == null) {
                    // listFiles() returns null when the directory cannot be listed (I/O error,
                    // missing permission); treat it as an empty directory.
                    return classes;
                }
                for (File child : children) {
                    classes.addAll(findClassQualifiedNames(root, child, packageFilters));
                }
                return classes;
            }
            if (file.getName().endsWith(CLASS_FILE_SUFFIX)) {
                String qualifiedName = classNameFor(root, file);
                if (isKept(qualifiedName, packageFilters)) {
                    return Collections.singleton(qualifiedName);
                }
            }
            return Collections.emptySet();
        }

        /** Tells whether {@code qualifiedName} contains none of the filter substrings. */
        private boolean isKept(String qualifiedName, Set<String> packageFilters) {
            for (String filter : packageFilters) {
                if (qualifiedName.contains(filter)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Computes the qualified class name of a {@code .class} file relative to its root.
         *
         * @param root the class path root directory
         * @param file a file below {@code root} whose name ends with {@code .class}
         * @return the path below the root with separators turned into dots and the suffix removed
         */
        private String classNameFor(File root, File file) {
            String relative = file.getAbsolutePath()
                    .substring(root.getAbsolutePath().length() + 1)
                    .replace(File.separatorChar, '.');
            // Strip the trailing suffix by length: searching for the first ".class" would cut
            // names short for packages such as "com.classes" or "org.classic".
            return relative.substring(0, relative.length() - CLASS_FILE_SUFFIX.length());
        }

    }
}