package org.mockitoutil;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.objenesis.Objenesis;
import org.objenesis.ObjenesisStd;
import org.objenesis.instantiator.ObjectInstantiator;
import static java.lang.String.format;
import static java.util.Arrays.asList;
/**
 * Test helpers for creating and working with custom {@link ClassLoader}s.
 *
 * <p>The static factory methods return builders (concrete subclasses of this class) whose
 * {@link #build()} produces a configured class loader:
 * <ul>
 *   <li>{@link #isolatedClassLoader()}: a URL class loader that only defines classes whose
 *       names start with a whitelisted prefix and do not start with an excluded prefix;</li>
 *   <li>{@link #excludingClassLoader()}: a URL class loader that refuses excluded prefixes;</li>
 *   <li>{@link #inMemoryClassLoader()}: a loader that defines classes from in-memory bytecode.</li>
 * </ul>
 * {@link #using(ClassLoader)} runs a task inside a given loader, and {@link #in(ClassLoader)}
 * lists the classes a loader owns.
 */
public abstract class ClassLoaders {
    /**
     * Parent loader for builders that honour it (only {@link InMemoryClassLoaderBuilder} does);
     * defaults to the loader that loaded {@code ClassLoaders}.
     */
    protected ClassLoader parent = currentClassLoader();

    protected ClassLoaders() {
    }

    /** Starts building a loader that only defines explicitly whitelisted class-name prefixes. */
    public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
        return new IsolatedURLClassLoaderBuilder();
    }

    /** Starts building a URL class loader that refuses to define excluded class-name prefixes. */
    public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
        return new ExcludingURLClassLoaderBuilder();
    }

    /** Starts building a loader that defines classes from bytecode held in memory. */
    public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
        return new InMemoryClassLoaderBuilder();
    }

    /** Returns a finder that lists the classes reachable from the given loader's roots. */
    public static ReachableClassesFinder in(ClassLoader classLoader) {
        return new ReachableClassesFinder(classLoader);
    }

    /** Returns the loader of {@code String.class}: the bootstrap loader, reported as {@code null}. */
    public static ClassLoader jdkClassLoader() {
        return String.class.getClassLoader();
    }

    /** Returns {@link ClassLoader#getSystemClassLoader()}. */
    public static ClassLoader systemClassLoader() {
        return ClassLoader.getSystemClassLoader();
    }

    /** Returns the loader that loaded {@code ClassLoaders} itself. */
    public static ClassLoader currentClassLoader() {
        return ClassLoaders.class.getClassLoader();
    }

    /** Creates the class loader configured by this builder. */
    public abstract ClassLoader build();

    /**
     * Returns those of the known coverage/logging tool classes (Cobertura's
     * {@code TouchCollector}, SLF4J's {@code LoggerFactory}) that are present on the classpath.
     */
    public static Class<?>[] coverageTool() {
        Set<Class<?>> classes = new HashSet<Class<?>>();
        for (String className : asList(
                "net.sourceforge.cobertura.coveragedata.TouchCollector",
                "org.slf4j.LoggerFactory")) {
            Class<?> clazz = safeGetClass(className);
            if (clazz != null) {
                classes.add(clazz);
            }
        }
        return classes.toArray(new Class<?>[classes.size()]);
    }

    /** Loads {@code className} via {@link Class#forName(String)}, or returns {@code null} if absent. */
    private static Class<?> safeGetClass(String className) {
        try {
            return Class.forName(className);
        } catch (ClassNotFoundException e) {
            return null;
        }
    }

    /** Returns an executor that runs tasks reloaded inside {@code classLoader}. */
    public static ClassLoaderExecutor using(final ClassLoader classLoader) {
        return new ClassLoaderExecutor(classLoader);
    }

    /**
     * Runs {@link Runnable} tasks after reloading their class through a given class loader,
     * on a dedicated thread whose context class loader is that loader.
     */
    public static class ClassLoaderExecutor {
        private final ClassLoader classLoader;

        public ClassLoaderExecutor(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /**
         * Reloads {@code task}'s class through the configured loader, copies compatible field
         * values onto the reloaded instance, and runs it on a fresh single thread whose context
         * class loader is the configured loader. Blocks until the task finishes.
         *
         * <p>Any {@link Throwable} raised while reloading or running the task is wrapped in an
         * {@link IllegalStateException}, and that exception is rethrown here directly (not
         * wrapped in an {@link ExecutionException}).
         *
         * @param task the task to reload and run
         * @throws InterruptedException if the calling thread is interrupted while waiting; the
         *         interrupt flag is restored before rethrowing
         * @throws Exception the {@link IllegalStateException} describing the task's failure
         */
        public void execute(final Runnable task) throws Exception {
            ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() {
                @Override
                public Thread newThread(Runnable r) {
                    Thread thread = Executors.defaultThreadFactory().newThread(r);
                    thread.setContextClassLoader(classLoader);
                    return thread;
                }
            });
            try {
                Future<?> taskFuture = executorService.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            reloadTaskInClassLoader(task).run();
                        } catch (Throwable throwable) {
                            throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s'",
                                    task,
                                    throwable.getMessage()),
                                    throwable);
                        }
                    }
                });
                taskFuture.get();
            } catch (InterruptedException e) {
                // Do not report an interrupted run as a success: restore the flag and propagate.
                Thread.currentThread().interrupt();
                throw e;
            } catch (ExecutionException e) {
                throw this.<Exception>unwrapAndThrows(e);
            } finally {
                // Always release the worker thread, including on failure or interruption.
                executorService.shutdownNow();
            }
        }

        /**
         * Rethrows the cause of {@code ex} without wrapping it. The unchecked cast lets the cause
         * escape even when it is a checked exception that {@code T} does not describe.
         */
        @SuppressWarnings("unchecked")
        private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
            throw (T) ex.getCause();
        }

        /**
         * Loads the class of {@code task} by name through the configured loader, instantiates it
         * without running a constructor (via Objenesis), and shallow-copies the task's declared
         * field values onto the new instance where the field types match exactly.
         *
         * @throws IllegalStateException if the class or one of its fields cannot be found in the
         *         configured loader, or a field cannot be accessed
         */
        Runnable reloadTaskInClassLoader(Runnable task) {
            try {
                @SuppressWarnings("unchecked")
                Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName());
                Objenesis objenesis = new ObjenesisStd();
                ObjectInstantiator<Runnable> instantiator = objenesis.getInstantiatorOf(taskClassReloaded);
                Runnable reloaded = instantiator.newInstance();
                // Lenient shallow copy of class-compatible fields.
                for (Field field : task.getClass().getDeclaredFields()) {
                    Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
                    int modifiers = declaredField.getModifiers();
                    if (Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
                        // Skip static final fields (e.g. jacoco fields): setting them would throw
                        // IllegalAccessException (bypassable with Unsafe). Coverage data may be missed.
                        continue;
                    }
                    // Copy only when both declarations have the very same field type. A field whose
                    // type was itself reloaded (e.g. a captured enclosing instance) is a different
                    // Class and cannot be assigned across loaders, so it is left unset.
                    if (declaredField.getType() == field.getType()) {
                        field.setAccessible(true);
                        declaredField.setAccessible(true);
                        declaredField.set(reloaded, field.get(task));
                    }
                }
                return reloaded;
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException(e);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            } catch (NoSuchFieldException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Builder for a {@link LocalIsolatedURLClassLoader}. The built loader is parented by
     * {@link #jdkClassLoader()} and only defines classes matching a private-copy prefix and not
     * matching an excluded prefix.
     */
    public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        /** Lets the built loader define its own copy of classes whose names start with these prefixes. */
        public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
            privateCopyPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        /** Adds file-system paths, resolved to absolute file URLs, as code sources. */
        public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        /** Adds the classpath root each given class is found under, as seen by {@code ClassLoaders}' loader. */
        public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        /** Adds the classpath root that {@code ClassLoaders} itself is found under. */
        public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        /** Prevents the built loader from defining classes whose names start with these prefixes. */
        public IsolatedURLClassLoaderBuilder without(String... prefixes) {
            excludedPrefixes.addAll(asList(prefixes));
            return this;
        }

        @Override
        public ClassLoader build() {
            return new LocalIsolatedURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    privateCopyPrefixes,
                    excludedPrefixes
            );
        }
    }

    /**
     * URL class loader whose {@link #findClass(String)} only defines classes that match a
     * private-copy prefix and no excluded prefix. Everything else is a
     * {@link ClassNotFoundException}, so only the parent can supply it.
     */
    static class LocalIsolatedURLClassLoader extends URLClassLoader {
        private final ArrayList<String> privateCopyPrefixes;
        private final ArrayList<String> excludedPrefixes;

        LocalIsolatedURLClassLoader(ClassLoader classLoader,
                                    URL[] urls,
                                    ArrayList<String> privateCopyPrefixes,
                                    ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            // Defensive copies: later changes to the builder must not affect this loader.
            this.privateCopyPrefixes = new ArrayList<String>(privateCopyPrefixes);
            this.excludedPrefixes = new ArrayList<String>(excludedPrefixes);
        }

        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) {
                throw new ClassNotFoundException(format("Can only load classes with prefixes : %s, but not : %s",
                        privateCopyPrefixes,
                        excludedPrefixes));
            }
            try {
                return super.findClass(name);
            } catch (ClassNotFoundException cnfe) {
                throw new ClassNotFoundException(format("%s%n%s%n",
                        cnfe.getMessage(),
                        " Did you forget to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"),
                        cnfe);
            }
        }

        private boolean classShouldBePrivate(String name) {
            for (String prefix : privateCopyPrefixes) {
                if (name.startsWith(prefix)) return true;
            }
            return false;
        }

        private boolean classShouldBeExcluded(String name) {
            for (String prefix : excludedPrefixes) {
                if (name.startsWith(prefix)) return true;
            }
            return false;
        }
    }

    /**
     * Builder for a {@link LocalExcludingURLClassLoader}. The built loader is parented by
     * {@link #jdkClassLoader()} and refuses to define classes matching an excluded prefix.
     */
    public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        /** Prevents the built loader from defining classes whose names start with these prefixes. */
        public ExcludingURLClassLoaderBuilder without(String... prefixes) {
            excludedPrefixes.addAll(asList(prefixes));
            return this;
        }

        /** Adds file-system paths, resolved to absolute file URLs, as code sources. */
        public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        /** Adds the classpath root each given class is found under, as seen by {@code ClassLoaders}' loader. */
        public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        /** Adds the classpath root that {@code ClassLoaders} itself is found under. */
        public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        @Override
        public ClassLoader build() {
            return new LocalExcludingURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    excludedPrefixes
            );
        }
    }

    /** URL class loader whose {@link #findClass(String)} rejects names matching an excluded prefix. */
    static class LocalExcludingURLClassLoader extends URLClassLoader {
        private final ArrayList<String> excludedPrefixes;

        LocalExcludingURLClassLoader(ClassLoader classLoader,
                                     URL[] urls,
                                     ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            // Defensive copy: later changes to the builder must not affect this loader.
            this.excludedPrefixes = new ArrayList<String>(excludedPrefixes);
        }

        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            if (classShouldBeExcluded(name)) {
                throw new ClassNotFoundException("classes with prefix : " + excludedPrefixes + " are excluded");
            }
            return super.findClass(name);
        }

        private boolean classShouldBeExcluded(String name) {
            for (String prefix : excludedPrefixes) {
                if (name.startsWith(prefix)) return true;
            }
            return false;
        }
    }

    /** Builder for an {@link InMemoryClassLoader} holding named bytecode definitions. */
    public static class InMemoryClassLoaderBuilder extends ClassLoaders {
        private final Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();

        /** Sets the parent of the built loader (defaults to {@link #currentClassLoader()}). */
        public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
            this.parent = parent;
            return this;
        }

        /** Registers bytecode for the class with binary name {@code name}; replaces any earlier definition. */
        public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
            inMemoryClassObjects.put(name, classDefinition);
            return this;
        }

        @Override
        public ClassLoader build() {
            return new InMemoryClassLoader(parent, inMemoryClassObjects);
        }
    }

    /**
     * Class loader that defines classes from an in-memory map of name to bytecode.
     * {@link #getResources(String)} ignores the requested name and returns one
     * {@code mem:<className>} URL per registered class.
     */
    static class InMemoryClassLoader extends ClassLoader {
        /** URL scheme used for the in-memory resource URLs. */
        public static final String SCHEME = "mem";
        private final Map<String, byte[]> inMemoryClassObjects;

        public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
            super(parent);
            // Defensive copy: later builder registrations must not affect this loader.
            this.inMemoryClassObjects = new HashMap<String, byte[]>(inMemoryClassObjects);
        }

        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            byte[] classDefinition = inMemoryClassObjects.get(name);
            if (classDefinition != null) {
                return defineClass(name, classDefinition, 0, classDefinition.length);
            }
            throw new ClassNotFoundException(name);
        }

        @Override
        public Enumeration<URL> getResources(String ignored) throws IOException {
            return inMemoryOnly();
        }

        /** Enumerates a {@code mem:} URL for every registered class name. */
        private Enumeration<URL> inMemoryOnly() {
            final Set<String> names = inMemoryClassObjects.keySet();
            return new Enumeration<URL>() {
                private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this);
                private final Iterator<String> it = names.iterator();

                @Override
                public boolean hasMoreElements() {
                    return it.hasNext();
                }

                @Override
                public URL nextElement() {
                    try {
                        return new URL(null, SCHEME + ":" + it.next(), memHandler);
                    } catch (MalformedURLException rethrown) {
                        throw new IllegalStateException(rethrown);
                    }
                }
            };
        }
    }

    /** {@link URLStreamHandler} that serves {@code mem:} URLs from an {@link InMemoryClassLoader}. */
    public static class MemHandler extends URLStreamHandler {
        private final InMemoryClassLoader inMemoryClassLoader;

        public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
            this.inMemoryClassLoader = inMemoryClassLoader;
        }

        @Override
        protected URLConnection openConnection(URL url) throws IOException {
            return new MemURLConnection(url, inMemoryClassLoader);
        }

        /** Connection that streams the bytecode registered under the URL's path (the class name). */
        private static class MemURLConnection extends URLConnection {
            private final InMemoryClassLoader inMemoryClassLoader;
            private final String qualifiedName;

            public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
                super(url);
                this.inMemoryClassLoader = inMemoryClassLoader;
                qualifiedName = url.getPath();
            }

            @Override
            public void connect() throws IOException {
                // Nothing to connect to: the content is already in memory.
            }

            @Override
            public InputStream getInputStream() throws IOException {
                byte[] classDefinition = inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName);
                if (classDefinition == null) {
                    // Report a missing entry through the declared IOException rather than an NPE.
                    throw new java.io.FileNotFoundException(getURL().toString());
                }
                return new ByteArrayInputStream(classDefinition);
            }
        }
    }

    /**
     * Returns the classpath root (directory or jar URL) under which {@code className}'s
     * {@code .class} file is found by the loader of {@code ClassLoaders}.
     *
     * @throws IllegalStateException if that loader cannot find the class file
     */
    URL obtainCurrentClassPathOf(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = ClassLoaders.class.getClassLoader().getResource(path);
        if (resource == null) {
            throw new IllegalStateException(format("Class file '%s' is not visible from the class loader of %s",
                    path,
                    ClassLoaders.class.getName()));
        }
        String url = resource.toExternalForm();
        try {
            // Strip the class-file path to keep only the root it was found under.
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /** Converts file-system paths to absolute {@code file:} URLs. */
    List<URL> pathsToURLs(String... codeSourceUrls) {
        return pathsToURLs(Arrays.asList(codeSourceUrls));
    }

    private List<URL> pathsToURLs(List<String> codeSourceUrls) {
        ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size());
        for (String codeSourceUrl : codeSourceUrls) {
            urls.add(pathToUrl(codeSourceUrl));
        }
        return urls;
    }

    private URL pathToUrl(String path) {
        try {
            return new File(path).getAbsoluteFile().toURI().toURL();
        } catch (MalformedURLException e) {
            throw new IllegalArgumentException("Path is malformed", e);
        }
    }

    /**
     * Lists the classes found under every root returned by {@code classLoader.getResources("")}.
     * It supports {@code file:} directory roots and {@code mem:} roots from
     * {@link InMemoryClassLoader}.
     */
    public static class ReachableClassesFinder {
        private final ClassLoader classLoader;
        private final Set<String> qualifiedNameSubstring = new HashSet<String>();

        ReachableClassesFinder(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /** Omits every class whose qualified name contains any of the given substrings. */
        public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
            this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
            return this;
        }

        /**
         * Returns the qualified names of the classes under the loader's roots, minus omitted ones.
         *
         * @throws IllegalArgumentException if a root is neither a {@code file:} nor a {@code mem:} URL
         * @throws IOException if the roots cannot be enumerated or a directory cannot be listed
         */
        public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
            Enumeration<URL> roots = classLoader.getResources("");
            Set<String> classes = new HashSet<String>();
            while (roots.hasMoreElements()) {
                URI uri = roots.nextElement().toURI();
                String scheme = uri.getScheme();
                if ("file".equalsIgnoreCase(scheme)) {
                    addFromFileBasedClassLoader(classes, uri);
                } else if (InMemoryClassLoader.SCHEME.equalsIgnoreCase(scheme)) {
                    addFromInMemoryBasedClassLoader(classes, uri);
                } else {
                    throw new IllegalArgumentException(format("Given ClassLoader '%s' has a root '%s' that is reachable neither by File nor via ClassLoaders.inMemory",
                            classLoader,
                            uri));
                }
            }
            return classes;
        }

        private void addFromFileBasedClassLoader(Set<String> classes, URI uri) throws IOException {
            File root = new File(uri);
            classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
        }

        private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
            // For "mem:<className>" the scheme-specific part is the class name itself.
            String qualifiedName = uri.getSchemeSpecificPart();
            if (isNotOmitted(qualifiedName, qualifiedNameSubstring)) {
                classes.add(qualifiedName);
            }
        }

        /** Recursively collects the names of the {@code .class} files under {@code file}, relative to {@code root}. */
        private Set<String> findClassQualifiedNames(File root, File file, Set<String> omittedSubstrings) throws IOException {
            if (file.isDirectory()) {
                File[] children = file.listFiles();
                if (children == null) {
                    // File.listFiles() returns null on I/O error; fail loudly instead of with an NPE.
                    throw new IOException("Could not list directory " + file.getAbsolutePath());
                }
                Set<String> classes = new HashSet<String>();
                for (File child : children) {
                    classes.addAll(findClassQualifiedNames(root, child, omittedSubstrings));
                }
                return classes;
            }
            if (file.getName().endsWith(".class")) {
                String qualifiedName = classNameFor(root, file);
                if (isNotOmitted(qualifiedName, omittedSubstrings)) {
                    return Collections.singleton(qualifiedName);
                }
            }
            return Collections.emptySet();
        }

        /** Returns {@code true} when {@code qualifiedName} contains none of the omitted substrings. */
        private boolean isNotOmitted(String qualifiedName, Set<String> omittedSubstrings) {
            for (String substring : omittedSubstrings) {
                if (qualifiedName.contains(substring)) return false;
            }
            return true;
        }

        /** Derives a binary class name from a {@code .class} file's path relative to {@code root}. */
        private String classNameFor(File root, File file) {
            String relativePath = file.getAbsolutePath().substring(root.getAbsolutePath().length() + 1);
            // Strip only the trailing extension; ".class" may also appear in a directory name.
            String withoutExtension = relativePath.substring(0, relativePath.length() - ".class".length());
            return withoutExtension.replace(File.separatorChar, '.');
        }
    }
}