// Java source  |  211 lines  |  5.11 KB

/*
 * Copyright (C) 2010 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.common.collect.testing;

import java.io.Serializable;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

/**
 * A wrapper around {@code TreeMap} that aggressively checks to see if keys are
 * mutually comparable. Every key handed to a key-based operation is first
 * compared against itself with the map's effective comparator, so a key that
 * cannot be ordered fails immediately. This implementation passes the
 * navigable map test suites.
 *
 * @author Louis Wasserman
 */
public final class SafeTreeMap<K, V> implements Serializable, SortedMap<K, V> {
  private static final long serialVersionUID = 0L;

  /**
   * Ordering used when the backing map reports no comparator: the left operand
   * is cast to {@code Comparable} and asked to compare itself to the right one.
   */
  @SuppressWarnings("unchecked")
  private static final Comparator<Object> NATURAL_ORDER = new Comparator<Object>() {
    @Override public int compare(Object left, Object right) {
      Comparable<Object> comparableLeft = (Comparable<Object>) left;
      return comparableLeft.compareTo(right);
    }
  };

  /** The sorted map that actually stores the entries. */
  private final SortedMap<K, V> delegate;

  /** Creates an empty map backed by a new {@code TreeMap} with no explicit comparator. */
  public SafeTreeMap() {
    this(new TreeMap<K, V>());
  }

  /**
   * Creates an empty map backed by a new {@code TreeMap} built with the given comparator.
   *
   * @param comparator the comparator handed to the backing {@code TreeMap}
   */
  public SafeTreeMap(Comparator<? super K> comparator) {
    this(new TreeMap<K, V>(comparator));
  }

  /**
   * Creates a map backed by a new {@code TreeMap} populated from {@code map}.
   *
   * @param map the mappings to copy into the backing {@code TreeMap}
   */
  public SafeTreeMap(Map<? extends K, ? extends V> map) {
    this(new TreeMap<K, V>(map));
  }

  /**
   * Wraps {@code delegate} directly and validates every key it already holds.
   *
   * @param delegate the map to wrap
   * @throws NullPointerException if {@code delegate} is null
   */
  private SafeTreeMap(SortedMap<K, V> delegate) {
    if (delegate == null) {
      throw new NullPointerException();
    }
    this.delegate = delegate;
    Iterator<K> existingKeys = delegate.keySet().iterator();
    while (existingKeys.hasNext()) {
      checkValid(existingKeys.next());
    }
  }

  /**
   * Compares {@code t} with itself using {@link #comparator()} and hands it back.
   * Any exception raised by that comparison (for example a
   * {@code ClassCastException}) is deliberately allowed to propagate.
   *
   * @param t the candidate key
   * @return {@code t}, unchanged
   */
  private <T> T checkValid(T t) {
    @SuppressWarnings("unchecked")
    K candidate = (K) t;
    Comparator<? super K> ordering = comparator();
    ordering.compare(candidate, candidate);
    return t;
  }

  // ---- Ordering and range views ----

  /**
   * Returns the backing map's comparator, or the natural-order comparator when
   * the backing map reports {@code null}.
   */
  @SuppressWarnings("unchecked")
  @Override public Comparator<? super K> comparator() {
    Comparator<? super K> explicit = delegate.comparator();
    if (explicit != null) {
      return explicit;
    }
    return (Comparator<? super K>) NATURAL_ORDER;
  }

  /** Returns the backing map's first key. */
  @Override public K firstKey() {
    return delegate.firstKey();
  }

  /** Returns the backing map's last key. */
  @Override public K lastKey() {
    return delegate.lastKey();
  }

  /** Validates {@code toKey}, then wraps the backing map's head view in a new safe map. */
  @Override public SortedMap<K, V> headMap(K toKey) {
    SortedMap<K, V> head = delegate.headMap(checkValid(toKey));
    return new SafeTreeMap<K, V>(head);
  }

  /** Validates both bounds in order, then wraps the backing map's sub view in a new safe map. */
  @Override public SortedMap<K, V> subMap(K fromKey, K toKey) {
    K lower = checkValid(fromKey);
    K upper = checkValid(toKey);
    return new SafeTreeMap<K, V>(delegate.subMap(lower, upper));
  }

  /** Validates {@code fromKey}, then wraps the backing map's tail view in a new safe map. */
  @Override public SortedMap<K, V> tailMap(K fromKey) {
    SortedMap<K, V> tail = delegate.tailMap(checkValid(fromKey));
    return new SafeTreeMap<K, V>(tail);
  }

  // ---- Queries ----

  /** Returns the backing map's size. */
  @Override public int size() {
    return delegate.size();
  }

  /** Reports whether the backing map is empty. */
  @Override public boolean isEmpty() {
    return delegate.isEmpty();
  }

  /**
   * Validates {@code key} and asks the backing map whether it is present. A
   * {@code NullPointerException} or {@code ClassCastException} raised by either
   * step is treated as "not present".
   */
  @Override public boolean containsKey(Object key) {
    boolean found;
    try {
      found = delegate.containsKey(checkValid(key));
    } catch (NullPointerException e) {
      found = false;
    } catch (ClassCastException e) {
      found = false;
    }
    return found;
  }

  /** Delegates to the backing map without any extra checking. */
  @Override public boolean containsValue(Object value) {
    return delegate.containsValue(value);
  }

  /** Validates {@code key} (failures propagate), then looks it up in the backing map. */
  @Override public V get(Object key) {
    Object validated = checkValid(key);
    return delegate.get(validated);
  }

  // ---- Mutation ----

  /** Validates {@code key} (failures propagate), then stores the mapping in the backing map. */
  @Override public V put(K key, V value) {
    K validated = checkValid(key);
    return delegate.put(validated, value);
  }

  /**
   * Validates every key of {@code map} first, and only then copies all mappings
   * into the backing map.
   */
  @Override public void putAll(Map<? extends K, ? extends V> map) {
    Iterator<? extends K> incomingKeys = map.keySet().iterator();
    while (incomingKeys.hasNext()) {
      checkValid(incomingKeys.next());
    }
    delegate.putAll(map);
  }

  /** Validates {@code key} (failures propagate), then removes it from the backing map. */
  @Override public V remove(Object key) {
    Object validated = checkValid(key);
    return delegate.remove(validated);
  }

  /** Removes every mapping from the backing map. */
  @Override public void clear() {
    delegate.clear();
  }

  // ---- Collection views ----

  /** Returns the backing map's key set as-is. */
  @Override public Set<K> keySet() {
    return delegate.keySet();
  }

  /** Returns the backing map's value collection as-is. */
  @Override public Collection<V> values() {
    return delegate.values();
  }

  /**
   * Returns a fresh view over the backing map's entry set whose
   * {@code contains} answers {@code false} instead of throwing.
   */
  @Override public Set<Entry<K, V>> entrySet() {
    return new EntrySetView();
  }

  /**
   * Entry-set view that forwards every call to the backing map's entry set,
   * re-fetching that set on each call.
   */
  private final class EntrySetView extends AbstractSet<Entry<K, V>> {
    private Set<Entry<K, V>> backing() {
      return delegate.entrySet();
    }

    @Override
    public int size() {
      return backing().size();
    }

    @Override
    public Iterator<Entry<K, V>> iterator() {
      return backing().iterator();
    }

    /**
     * Forwards to the backing entry set; a {@code NullPointerException} or
     * {@code ClassCastException} from it is reported as {@code false}.
     */
    @Override
    public boolean contains(Object object) {
      boolean present;
      try {
        present = backing().contains(object);
      } catch (NullPointerException e) {
        present = false;
      } catch (ClassCastException e) {
        present = false;
      }
      return present;
    }

    @Override
    public boolean remove(Object o) {
      return backing().remove(o);
    }

    @Override
    public void clear() {
      backing().clear();
    }
  }

  // ---- Object methods: all answered by the backing map ----

  @Override public boolean equals(Object obj) {
    return delegate.equals(obj);
  }

  @Override public int hashCode() {
    return delegate.hashCode();
  }

  @Override public String toString() {
    return delegate.toString();
  }
}