LockFreeSet
java
Wrap Lines
Raw import java.util.Iterator;
import java.util.concurrent.atomic.AtomicStampedReference;
import java.util.ArrayList;
import java.util.List;
/**
 * A lock-free sorted set backed by a singly linked list with logical deletion (Harris-style).
 *
 * <p>Every node's successor pointer is an {@link AtomicStampedReference}. Its stamp plays two
 * roles:
 * <ul>
 *   <li>the value {@link #MARKED} ({@code -1}) marks the node as logically deleted;</li>
 *   <li>any other value is a version counter, bumped on every successful change of the
 *       successor, which {@link #iterator()} uses to validate its snapshot.</li>
 * </ul>
 * Once a node is marked, no operation changes its successor or its stamp again. Every CAS in
 * this class expects a non-marked stamp, so marks are permanent. A marked node is physically
 * unlinked either by the {@link #remove} that marked it or by a later traversal in {@code find}.
 *
 * <p>{@code null} elements are not permitted: {@link #add}, {@link #remove} and
 * {@link #contains} throw {@link NullPointerException} for a {@code null} argument.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class SetImpl<T extends Comparable<T>> implements Set<T> {

    /** Stamp value that marks a node as logically deleted. */
    private static final int MARKED = -1;

    /** A list node; {@code item} is {@code null} only for the head sentinel. */
    private static final class Node<E> {
        final E item;
        final AtomicStampedReference<Node<E>> nextRef;

        Node(E item, Node<E> next) {
            this.item = item;
            this.nextRef = new AtomicStampedReference<>(next, 0);
        }
    }

    /**
     * Result of {@code find}: {@code lower} is the last node ordered before the searched value
     * (the head sentinel when there is none, never {@code null}); {@code upper} is the first
     * unmarked node ordered at or after it, or {@code null} at the end of the list.
     */
    private static final class Bounds<E> {
        final Node<E> lower;
        final Node<E> upper;

        Bounds(Node<E> lower, Node<E> upper) {
            this.lower = lower;
            this.upper = upper;
        }
    }

    /** Sentinel head; it is never marked and never removed. */
    private final Node<T> head = new Node<>(null, null);

    /** Creates an empty set. */
    public SetImpl() {
    }

    /**
     * Adds {@code value} if no equal element is present.
     *
     * @param value element to add; must not be {@code null}
     * @return {@code true} if the set changed, {@code false} if an equal element was present
     * @throws NullPointerException if {@code value} is {@code null}
     */
    @Override
    public boolean add(T value) {
        checkValue(value);
        while (true) {
            Bounds<T> bounds = find(value);
            Node<T> prev = bounds.lower;
            Node<T> curr = bounds.upper;
            if (curr != null && value.compareTo(curr.item) == 0) {
                return false;
            }
            int prevStamp = prev.nextRef.getStamp();
            if (prevStamp == MARKED) {
                // prev was deleted concurrently; never link a new node behind a deleted one.
                continue;
            }
            Node<T> node = new Node<>(value, curr);
            // Succeeds only if prev still points at curr and was not modified since the stamp read.
            if (prev.nextRef.compareAndSet(curr, node, prevStamp, nextStamp(prevStamp))) {
                return true;
            }
        }
    }

    /**
     * Removes the element equal to {@code value}, if present.
     *
     * @param value element to remove; must not be {@code null}
     * @return {@code true} if this call removed the element
     * @throws NullPointerException if {@code value} is {@code null}
     */
    @Override
    public boolean remove(T value) {
        checkValue(value);
        while (true) {
            Bounds<T> bounds = find(value);
            Node<T> prev = bounds.lower;
            Node<T> curr = bounds.upper;
            if (curr == null || value.compareTo(curr.item) != 0) {
                return false;
            }
            int[] stampHolder = {0};
            Node<T> next = curr.nextRef.get(stampHolder);
            int currStamp = stampHolder[0];
            if (currStamp == MARKED) {
                // Another thread deleted this node first.
                return false;
            }
            // Logical deletion: mark curr. Fails if curr's successor changed meanwhile -> retry.
            if (!curr.nextRef.compareAndSet(next, next, currStamp, MARKED)) {
                continue;
            }
            // Physical deletion is best-effort; a later find() unlinks curr if this fails.
            tryUnlink(prev, curr, next);
            return true;
        }
    }

    /**
     * Reports whether an element equal to {@code value} is present. Wait-free: it never helps
     * unlink and never retries.
     *
     * @param value element to look up; must not be {@code null}
     * @return {@code true} if an unmarked node holding an equal element was found
     * @throws NullPointerException if {@code value} is {@code null}
     */
    @Override
    public boolean contains(T value) {
        checkValue(value);
        Node<T> curr = head.nextRef.getReference();
        while (curr != null) {
            int[] stampHolder = {0};
            Node<T> next = curr.nextRef.get(stampHolder);
            int cmp = curr.item.compareTo(value);
            if (cmp >= 0) {
                return cmp == 0 && stampHolder[0] != MARKED;
            }
            curr = next;
        }
        return false;
    }

    /**
     * Reports whether the set has no elements, based on a validated snapshot.
     *
     * @return {@code true} if the set is empty
     */
    @Override
    public boolean isEmpty() {
        return !iterator().hasNext();
    }

    /**
     * Returns an iterator over a consistent, sorted snapshot of the set. Later modifications are
     * not reflected. The iterator does not support {@code remove()}.
     *
     * <p>Building the snapshot retries until no concurrent modification interferes. It can
     * therefore spin while writers keep modifying the list.
     */
    @Override
    public Iterator<T> iterator() {
        final List<T> snapshot = snapshot();
        return new Iterator<T>() {
            private int index = 0;

            @Override
            public boolean hasNext() {
                return index < snapshot.size();
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                return snapshot.get(index++);
            }
        };
    }

    /**
     * Collects the unmarked items in list order, then re-reads the head's stamp and the stamp of
     * every collected node. If all are unchanged, the list reachable from the head is exactly the
     * one traversed. The collected items are then the set's contents at validation time.
     * Otherwise it retries.
     */
    private List<T> snapshot() {
        while (true) {
            int[] stampHolder = {0};
            Node<T> curr = head.nextRef.get(stampHolder);
            int headStamp = stampHolder[0];
            List<Node<T>> nodes = new ArrayList<>();
            List<Integer> stamps = new ArrayList<>();
            List<T> items = new ArrayList<>();
            while (curr != null) {
                Node<T> next = curr.nextRef.get(stampHolder);
                if (stampHolder[0] != MARKED) {
                    nodes.add(curr);
                    stamps.add(stampHolder[0]);
                    items.add(curr.item);
                }
                curr = next;
            }
            if (isUnchanged(headStamp, nodes, stamps)) {
                return items;
            }
        }
    }

    /** Returns whether the head and every collected node still carry the recorded stamps. */
    private boolean isUnchanged(int headStamp, List<Node<T>> nodes, List<Integer> stamps) {
        // The head must be validated too, otherwise an insertion at the front goes unnoticed.
        if (head.nextRef.getStamp() != headStamp) {
            return false;
        }
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes.get(i).nextRef.getStamp() != stamps.get(i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Locates the insertion window for {@code value}, unlinking every marked node it passes.
     * Restarts from the head whenever an unlink fails.
     */
    private Bounds<T> find(T value) {
        retry:
        while (true) {
            Node<T> prev = head;
            Node<T> curr = head.nextRef.getReference();
            while (curr != null) {
                int[] stampHolder = {0};
                Node<T> next = curr.nextRef.get(stampHolder);
                if (stampHolder[0] == MARKED) {
                    if (!tryUnlink(prev, curr, next)) {
                        continue retry;
                    }
                    // prev is unchanged; its successor is now next.
                    curr = next;
                } else if (curr.item.compareTo(value) < 0) {
                    prev = curr;
                    curr = next;
                } else {
                    return new Bounds<>(prev, curr);
                }
            }
            return new Bounds<>(prev, null);
        }
    }

    /**
     * Attempts to swing {@code prev}'s successor from the marked node {@code curr} to
     * {@code next}. Refuses when {@code prev} is itself marked: a CAS expecting stamp
     * {@link #MARKED} would overwrite that mark and resurrect a deleted node.
     *
     * @return {@code true} if the unlink CAS succeeded
     */
    private boolean tryUnlink(Node<T> prev, Node<T> curr, Node<T> next) {
        int[] stampHolder = {0};
        Node<T> succ = prev.nextRef.get(stampHolder);
        int prevStamp = stampHolder[0];
        return succ == curr
                && prevStamp != MARKED
                && prev.nextRef.compareAndSet(curr, next, prevStamp, nextStamp(prevStamp));
    }

    /**
     * Returns the version stamp following {@code stamp}. It wraps from {@link Integer#MAX_VALUE}
     * back to 0, so an ordinary update can never produce {@link #MARKED}.
     */
    private static int nextStamp(int stamp) {
        return stamp == Integer.MAX_VALUE ? 0 : stamp + 1;
    }

    /** Rejects {@code null}: a null item would break ordering and collide with the head sentinel. */
    private static void checkValue(Object value) {
        if (value == null) {
            throw new NullPointerException("value must not be null");
        }
    }
}