// LockFreeSet

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicStampedReference;

/**
 * A concurrent sorted set backed by a singly linked list, updated only with CAS operations
 * (no locks), in the style of Harris's non-blocking linked list.
 *
 * <p>Each node's {@code nextRef} pairs the successor pointer with an integer stamp. Every
 * successful structural CAS on a link bumps that link's stamp, and the reserved stamp
 * {@link #MARKED} flags the owning node as logically removed. Once a node is marked, its
 * {@code nextRef} is never modified again. That invariant is what keeps removals from being
 * undone and what lets {@link #iterator()} take an atomic snapshot.
 *
 * <p>Elements must be non-null and are kept in ascending natural order.
 *
 * @param <T> element type
 */
public class SetImpl<T extends Comparable<T>> implements Set<T> {

    /** Stamp value reserved to mean "this node has been logically removed". */
    private static final int MARKED = -1;

    /** A list cell holding one element and a stamped link to its successor. */
    private static final class Node<E> {
        final E item;
        final AtomicStampedReference<Node<E>> nextRef;

        Node(E item, Node<E> next) {
            this.item = item;
            this.nextRef = new AtomicStampedReference<>(next, 0);
        }
    }

    /**
     * Result of {@link #find}: {@code lower} is the last live node ordered before the searched
     * value (possibly the head sentinel), and {@code upper} is the first live node whose item
     * is {@code >=} the value ({@code null} if none).
     */
    private static final class Bounds<E> {
        final Node<E> lower;
        final Node<E> upper;

        Bounds(Node<E> lower, Node<E> upper) {
            this.lower = lower;
            this.upper = upper;
        }
    }

    /** Sentinel with a null item. It is never marked and never unlinked. */
    private final Node<T> head;

    /** Creates an empty set. */
    public SetImpl() {
        head = new Node<T>(null, null);
    }

    /**
     * Inserts {@code value} if it is not already present.
     *
     * @param value element to add; must not be null
     * @return {@code true} if the element was inserted, {@code false} if it was already present
     * @throws NullPointerException if {@code value} is null
     */
    @Override
    public boolean add(T value) {
        Objects.requireNonNull(value, "value");
        while (true) {
            Bounds<T> bounds = find(value);
            Node<T> prev = bounds.lower;
            Node<T> curr = bounds.upper;

            if (curr != null && value.compareTo(curr.item) == 0) {
                return false;
            }

            int[] prevStamp = new int[1];
            Node<T> succ = prev.nextRef.get(prevStamp);
            // Never link behind a logically removed predecessor; re-find instead (find snips it).
            if (prevStamp[0] == MARKED || succ != curr) {
                continue;
            }

            Node<T> node = new Node<>(value, curr);
            if (prev.nextRef.compareAndSet(curr, node, prevStamp[0], nextStamp(prevStamp[0]))) {
                return true;
            }
        }
    }

    /**
     * Removes {@code value} if present.
     *
     * <p>Removal is two-phase. First the node is marked, which is the point where the element
     * leaves the set. Then it is unlinked. If the unlink loses a race, a subsequent
     * {@link #find} completes it.
     *
     * @param value element to remove; must not be null
     * @return {@code true} if this call removed the element, {@code false} if it was absent
     * @throws NullPointerException if {@code value} is null
     */
    @Override
    public boolean remove(T value) {
        Objects.requireNonNull(value, "value");
        while (true) {
            Bounds<T> bounds = find(value);
            Node<T> prev = bounds.lower;
            Node<T> curr = bounds.upper;

            if (curr == null || value.compareTo(curr.item) != 0) {
                return false;
            }

            int[] stamp = new int[1];
            Node<T> succ = curr.nextRef.get(stamp);
            if (stamp[0] == MARKED) {
                // A concurrent remove marked it first.
                return false;
            }

            // Logical removal. Fails (and retries) if curr's successor changed concurrently.
            if (!curr.nextRef.compareAndSet(succ, succ, stamp[0], MARKED)) {
                continue;
            }

            // Physical removal is best effort; find() snips the marked node if this fails.
            if (!unlink(prev, curr, succ)) {
                find(value);
            }
            return true;
        }
    }

    /**
     * Reports whether {@code value} is present (its node is reachable and not marked).
     *
     * @param value element to look for; must not be null
     * @return {@code true} if the element is in the set
     * @throws NullPointerException if {@code value} is null
     */
    @Override
    public boolean contains(T value) {
        Objects.requireNonNull(value, "value");
        Node<T> curr = head.nextRef.getReference();
        while (curr != null) {
            int[] stamp = new int[1];
            Node<T> succ = curr.nextRef.get(stamp);
            int cmp = curr.item.compareTo(value);
            if (cmp >= 0) {
                return cmp == 0 && stamp[0] != MARKED;
            }
            curr = succ;
        }
        return false;
    }

    /**
     * Reports whether the set has no elements, based on an atomic snapshot.
     *
     * @return {@code true} if the set contained no elements at the snapshot instant
     */
    @Override
    public boolean isEmpty() {
        return !iterator().hasNext();
    }

    /**
     * Returns an iterator over an atomic snapshot of the set, in ascending order. Later
     * modifications are not reflected, and the iterator does not support {@code remove()}.
     *
     * @return iterator over the snapshot
     */
    @Override
    public Iterator<T> iterator() {
        final List<T> snapshot = takeSnapshot();
        return new Iterator<T>() {
            private int index = 0;

            @Override
            public boolean hasNext() {
                return index < snapshot.size();
            }

            @Override
            public T next() {
                if (index >= snapshot.size()) {
                    throw new NoSuchElementException();
                }
                return snapshot.get(index++);
            }
        };
    }

    /**
     * Collects the live elements as they all existed at one instant, using repeated double
     * collect.
     *
     * <p>A pass records the head stamp plus every unmarked node and its link stamp. The pass
     * is accepted only if none of those stamps changed afterwards. The reasoning:
     * <ul>
     *   <li>every insert or unlink bumps the predecessor's stamp;</li>
     *   <li>every removal marks the node;</li>
     *   <li>marked nodes never change.</li>
     * </ul>
     * So unchanged stamps mean the observed chain was intact at validation time.
     */
    private List<T> takeSnapshot() {
        while (true) {
            List<Node<T>> nodes = new ArrayList<>();
            List<Integer> stamps = new ArrayList<>();
            List<T> items = new ArrayList<>();

            int[] stamp = new int[1];
            Node<T> curr = head.nextRef.get(stamp);
            int headStamp = stamp[0];

            while (curr != null) {
                Node<T> succ = curr.nextRef.get(stamp);
                if (stamp[0] != MARKED) {
                    nodes.add(curr);
                    stamps.add(stamp[0]);
                    items.add(curr.item);
                }
                curr = succ;
            }

            boolean consistent = head.nextRef.getStamp() == headStamp;
            for (int i = 0; consistent && i < nodes.size(); i++) {
                consistent = nodes.get(i).nextRef.getStamp() == stamps.get(i);
            }
            if (consistent) {
                return items;
            }
        }
    }

    /**
     * Locates the window around {@code value}, physically unlinking any marked nodes it
     * passes.
     *
     * @return bounds whose {@code lower} is never null (the head sentinel at minimum)
     */
    private Bounds<T> find(T value) {
        retry:
        while (true) {
            Node<T> prev = head;
            Node<T> curr = head.nextRef.getReference();

            while (curr != null) {
                int[] stamp = new int[1];
                Node<T> succ = curr.nextRef.get(stamp);

                if (stamp[0] == MARKED) {
                    // succ is final: a marked node's link never changes again.
                    if (!unlink(prev, curr, succ)) {
                        continue retry;
                    }
                    curr = succ;
                    continue;
                }

                if (curr.item.compareTo(value) >= 0) {
                    return new Bounds<>(prev, curr);
                }
                prev = curr;
                curr = succ;
            }
            return new Bounds<>(prev, null);
        }
    }

    /**
     * Tries to swing {@code prev}'s link from {@code curr} to {@code succ}.
     *
     * <p>It refuses to touch a marked {@code prev}. A CAS that expects the {@link #MARKED}
     * stamp would overwrite it with a fresh stamp, silently un-marking (resurrecting) a node
     * that was already reported removed.
     *
     * @return {@code true} if the link was changed
     */
    private boolean unlink(Node<T> prev, Node<T> curr, Node<T> succ) {
        int[] prevStamp = new int[1];
        Node<T> actual = prev.nextRef.get(prevStamp);
        return prevStamp[0] != MARKED
                && actual == curr
                && prev.nextRef.compareAndSet(curr, succ, prevStamp[0], nextStamp(prevStamp[0]));
    }

    /**
     * Returns the stamp that follows {@code stamp}, skipping the reserved {@link #MARKED}
     * value. Int overflow wraps around, and without the skip it would eventually produce
     * {@code MARKED} and spuriously mark a live node.
     */
    private static int nextStamp(int stamp) {
        int next = stamp + 1;
        return next == MARKED ? next + 1 : next;
    }
}