Tuesday 28 June 2016

Direct Memory Access in Java

Java is a safe programming language that prevents programmers from making many mistakes, most of them related to memory management.
Java nevertheless contains a “backdoor” that provides a number of low-level operations for manipulating memory and threads directly. This backdoor class, "sun.misc.Unsafe", is widely used by the JDK itself. This article is a quick overview of the "sun.misc.Unsafe" API and a few interesting cases of its usage.

There is no simple way to create an instance of Unsafe, such as Unsafe unsafe = new Unsafe(), because the class has a private constructor.
If you try to call the static method Unsafe.getUnsafe(), you will get a SecurityException.
That method is meant for trusted code only: it checks that the calling code was loaded by the primary classloader.
We can make our code "trusted" by putting it on the bootclasspath, but that is too cumbersome in practice.

For example:

java -Xbootclasspath:/usr/jdk1.7.0_56/jre/lib/rt.jar:. com.sample.client.Unsafe

The easy way is to read the Unsafe instance stored in the private field theUnsafe via reflection. Your IDE may report an error for this. If the error is annoying, downgrade forbidden references to Unsafe from an error to a warning (in Eclipse) under
Preferences -> Java -> Compiler -> Errors/Warnings -> Deprecated and restricted API -> Forbidden reference -> Warning

Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
Unsafe unsafe = (Unsafe) f.get(null);
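The later snippets call a getUnsafe() helper that is never shown in this article. A minimal sketch of such a helper, built on the reflection trick above, might look like the following. The class name UnsafeAccess is an assumption made for illustration, and the code is written against sun.misc.Unsafe as it exists in JDK 8; newer JDKs may print deprecation warnings.

```java
import java.lang.reflect.Field;
import sun.misc.Unsafe;

// Hypothetical helper class; the name UnsafeAccess is not part of any JDK API.
final class UnsafeAccess {
    private static final Unsafe UNSAFE = load();

    private UnsafeAccess() { }

    private static Unsafe load() {
        try {
            // "theUnsafe" is the private static singleton field inside sun.misc.Unsafe.
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null); // static field, so the receiver is null
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    // Returns the shared Unsafe instance, wrapping reflection failures
    // in an unchecked exception so callers need no throws clause.
    static Unsafe getUnsafe() {
        return UNSAFE;
    }

    public static void main(String[] args) {
        System.out.println(getUnsafe() != null);
    }
}
```

Caching the instance in a static final field means the reflective lookup happens only once per class load.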

The API consists of a few groups of important methods:

Memory

Direct memory access methods

getInt, putInt, copyMemory, freeMemory, getAddress, allocateMemory

Arrays

Array manipulation methods

arrayIndexScale, arrayBaseOffset

Classes

Methods for class and static field manipulation

defineClass, staticFieldOffset, defineAnonymousClass, ensureClassInitialized

Synchronization

For synchronization

monitorEnter, monitorExit, compareAndSwapInt, putOrderedInt, tryMonitorEnter

Info

Returns some low-level memory information

addressSize, pageSize

Objects

Methods for manipulating objects and their fields

allocateInstance, objectFieldOffset
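To give a feel for the Info and Arrays groups, here is a small sketch that prints the address and page size and reads an array element through a raw offset. The class and method names (UnsafeInfoDemo, readElement) are assumptions made for illustration; only the Unsafe calls come from the JDK.

```java
import java.lang.reflect.Field;
import sun.misc.Unsafe;

// Hypothetical demo class; only the Unsafe calls themselves come from the JDK.
class UnsafeInfoDemo {
    static Unsafe getUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    // Reads data[index] through a raw offset instead of normal array indexing.
    // No bounds check happens here, so the caller must pass a valid index.
    static int readElement(int[] data, int index) {
        Unsafe unsafe = getUnsafe();
        long base = unsafe.arrayBaseOffset(int[].class);   // offset of element 0
        long scale = unsafe.arrayIndexScale(int[].class);  // bytes per element
        return unsafe.getInt(data, base + index * scale);
    }

    public static void main(String[] args) {
        Unsafe unsafe = getUnsafe();
        System.out.println("addressSize: " + unsafe.addressSize());
        System.out.println("pageSize: " + unsafe.pageSize());
        System.out.println("third element: " + readElement(new int[] {10, 20, 30}, 2)); // 30
    }
}
```

The address and page size depend on the platform, so their printed values will differ between machines.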


Use Cases

Initialization

The allocateInstance method is useful when you need to skip the object initialization phase, or when you want an instance of a class that has no public constructor.

A o1 = new A(); // constructor
o1.a();

A o2 = A.class.newInstance(); // reflection
o2.a();

A o3 = (A) unsafe.allocateInstance(A.class); // unsafe
o3.a(); // prints 0
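The snippet above does not show class A. A self-contained sketch, assuming A simply sets a field to 1 in its constructor, makes the difference visible. The names AllocateInstanceDemo, viaConstructor and viaUnsafe are hypothetical.

```java
import java.lang.reflect.Field;
import sun.misc.Unsafe;

class AllocateInstanceDemo {
    // Assumed shape of class A: the constructor is the only place the field is set.
    static class A {
        private long a;

        A() {
            this.a = 1;
        }

        long a() {
            return a;
        }
    }

    static Unsafe getUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    // Normal construction runs the constructor, so the field is 1.
    static long viaConstructor() {
        return new A().a();
    }

    // allocateInstance only allocates the object; the constructor never runs,
    // so the field keeps its default value of 0.
    static long viaUnsafe() {
        try {
            A o = (A) getUnsafe().allocateInstance(A.class);
            return o.a();
        } catch (InstantiationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(viaConstructor()); // 1
        System.out.println(viaUnsafe());      // 0
    }
}
```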

Memory corruption

Consider a simple class that checks access rules:

class Sample {
    private int ACCESS_ALLOWED = 1;

    public boolean giveAccess() {
        return 42 == ACCESS_ALLOWED;
    }
}

For clients, it always returns false.

Sample sample = new Sample();
sample.giveAccess();   // false, no access

// bypass
Unsafe unsafe = getUnsafe();
Field f = sample.getClass().getDeclaredField("ACCESS_ALLOWED");
unsafe.putInt(sample, unsafe.objectFieldOffset(f), 42); // memory corruption

sample.giveAccess(); // true, access granted

Now all clients will get unlimited access.

sizeOf

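A much simpler sizeOf can be achieved by reading the size value directly from the class structure of the object. The offsets used here (4L and 12L) depend on the JVM's internal object layout, so treat them as specific to one particular JVM build.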

public static long sizeOf(Object object){
    return getUnsafe().getAddress(
        normalize(getUnsafe().getInt(object, 4L)) + 12L);
}
private static long normalize(int value) {
    if(value >= 0) return value;
    return (~0L >>> 32) & value;
}

normalize casts a signed int to an unsigned long, so that the value can be used correctly as an address.
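A different technique that avoids hard-coded header offsets is to estimate the shallow size from field offsets. The sketch below is an approximation under stated assumptions: objects are aligned to 8 bytes, the object has at least one instance field, and it is not an array. The names ShallowSizeSketch and Sample are hypothetical.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import sun.misc.Unsafe;

class ShallowSizeSketch {
    // Hypothetical class used only to have something to measure.
    static class Sample {
        long value;
        int flag;
    }

    static Unsafe getUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    // Estimates the shallow size: find the largest instance field offset
    // across the class hierarchy, then round up to the next multiple of 8.
    // Assumes 8-byte object alignment; not valid for arrays or field-less objects.
    static long sizeOf(Object o) {
        Unsafe unsafe = getUnsafe();
        long maxOffset = 0;
        for (Class<?> c = o.getClass(); c != null; c = c.getSuperclass()) {
            for (Field f : c.getDeclaredFields()) {
                if (!Modifier.isStatic(f.getModifiers())) {
                    maxOffset = Math.max(maxOffset, unsafe.objectFieldOffset(f));
                }
            }
        }
        return ((maxOffset / 8) + 1) * 8;
    }

    public static void main(String[] args) {
        // The exact number depends on the JVM and its flags.
        System.out.println(sizeOf(new Sample()));
    }
}
```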

Shallow copy

With a way to calculate the shallow object size, we can simply add a function that copies objects.

static Object shallowCopy(Object obj) {
    long size = sizeOf(obj);
    long start = toAddress(obj);
    long address = getUnsafe().allocateMemory(size);
    getUnsafe().copyMemory(start, address, size);
    return fromAddress(address);
}

toAddress and fromAddress convert an object to its address in memory and vice versa.

static long toAddress(Object obj) {
    Object[] array = new Object[] {obj};
    long baseOffset = getUnsafe().arrayBaseOffset(Object[].class);
    return normalize(getUnsafe().getInt(array, baseOffset));
}

static Object fromAddress(long address) {
    Object[] array = new Object[] {null};
    long baseOffset = getUnsafe().arrayBaseOffset(Object[].class);
    getUnsafe().putLong(array, baseOffset, address);
    return array[0];
}

This copy function can be used to copy an object of any type; its size is calculated dynamically. After copying, you need to cast the object to the specific type.

Hide Password

One more interesting use of direct memory access in Unsafe is removing unwanted objects from memory.

Most APIs for retrieving a user's password return byte[] or char[]. Why arrays?

It is purely for security reasons: we can overwrite the array elements once we no longer need them. If we retrieve the password as a String, it is stored as an object in memory, and setting the reference to null only removes that reference. The object stays in memory until the GC decides to clean it up.
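To make the array argument concrete, here is a minimal sketch that uses only the standard library. The names PasswordWipeDemo, check and useAndWipe are hypothetical, and check stands in for whatever real authentication would happen.

```java
import java.util.Arrays;

class PasswordWipeDemo {
    // Placeholder for a real authentication step (hypothetical).
    static boolean check(char[] password) {
        return password.length > 0;
    }

    // Uses the password, then overwrites every character so the secret
    // no longer sits in memory waiting for the garbage collector.
    static char[] useAndWipe(char[] password) {
        try {
            check(password);
        } finally {
            Arrays.fill(password, '\0');
        }
        return password;
    }

    public static void main(String[] args) {
        char[] password = "s3cret".toCharArray();
        useAndWipe(password);
        System.out.println(Arrays.equals(password, new char[6])); // true
    }
}
```

The finally block makes sure the array is wiped even if the authentication step throws.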

String password = new String("vim100@my##$$");
String fake = new String(password.replaceAll(".", "?"));
System.out.println(password); // vim100@my##$$
System.out.println(fake); // ?????????????

getUnsafe().copyMemory(
          fake, 0L, null, toAddress(password), sizeOf(password));

System.out.println(password); // ?????????????

System.out.println(fake); // ?????????????

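A safer way is to overwrite the String's backing array through reflection (this assumes the value field is a char[], as in Java 8 and earlier):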

Field stringValue = String.class.getDeclaredField("value");
stringValue.setAccessible(true);
char[] mem = (char[]) stringValue.get(password);
for (int i=0; i < mem.length; i++) {
  mem[i] = '?';
}

Multiple Inheritance

There is no multiple inheritance in Java.

Correct, except that we can cast any type to any other type if we want.

long intClassAddress = normalize(getUnsafe().getInt(new Integer(0), 4L));
long strClassAddress = normalize(getUnsafe().getInt("", 4L));
getUnsafe().putAddress(intClassAddress + 36, strClassAddress);

This snippet adds the String class to the superclasses of Integer, so we can cast without a runtime exception.

(String) (Object) (new Integer(666))

One catch: we must cast to Object first, to trick the compiler.

Dynamic classes


We can create classes at runtime, for example from a compiled .class file. To do that, read the class contents into a byte array and pass it to the defineClass method.

byte[] classContents = getClassContent();
Class c = getUnsafe().defineClass(
              null, classContents, 0, classContents.length);
c.getMethod("a").invoke(c.newInstance(), null); // 1



private static byte[] getClassContent() throws Exception {
    File f = new File("/home/mishadoff/tmp/A.class");
    FileInputStream input = new FileInputStream(f);
    byte[] content = new byte[(int)f.length()];
    input.read(content);
    input.close();
    return content;
}

Throw an Exception

getUnsafe().throwException(new IOException());

This method throws a checked exception, but your code is not forced to catch or rethrow it, just as with a runtime exception.
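A short sketch shows what this means in practice: the method below has no throws clause, yet an IOException comes out of it. The class and method names (SneakyThrowDemo, fail, capture) are hypothetical.

```java
import java.io.IOException;
import java.lang.reflect.Field;
import sun.misc.Unsafe;

class SneakyThrowDemo {
    static Unsafe getUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    // Note the missing "throws IOException": the compiler cannot see
    // that a checked exception leaves this method.
    static void fail() {
        getUnsafe().throwException(new IOException("disk unavailable"));
    }

    // Catches whatever fail() throws. A catch (IOException e) block would not
    // compile here, because the compiler believes the try body cannot throw it.
    static Throwable capture() {
        try {
            fail();
            return null;
        } catch (Throwable t) {
            return t;
        }
    }

    public static void main(String[] args) {
        System.out.println(capture() instanceof IOException); // true
    }
}
```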


Fast Serialization

The standard Java Serializable mechanism is known to be slow.

Externalizable is faster, but it requires the class to have a public no-argument constructor and to define its own schema for serialization.

Serialization:

    1. Build a schema for the object using reflection. This can be done once per class.
    2. Use Unsafe methods such as getLong, getInt and getObject to retrieve the actual field values.
    3. Add a class identifier so that the object can be restored later.
    4. Write everything to a file or any other output.

You can also add compression to save space.

Deserialization:

    1. Create an instance of the serialized class. allocateInstance helps here, because it does not require any constructor.
    2. Build the schema, the same as step 1 of serialization.
    3. Read all fields from the file or any other input.
    4. Use Unsafe methods such as putLong, putInt and putObject to fill the object.

Serialization done this way can be very fast.

Kryo is an example of a library that can use Unsafe in this way.
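The serialization and deserialization steps above can be sketched for a single class as follows. This is a simplified illustration, not how Kryo or any other library does it: the "schema" is just two cached field offsets, the class identifier step is left out, and the names UnsafeSerializationSketch, Point, write and read are hypothetical.

```java
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import sun.misc.Unsafe;

class UnsafeSerializationSketch {
    // Hypothetical class to serialize; it has no no-argument constructor.
    static class Point {
        int x;
        long y;

        Point(int x, long y) {
            this.x = x;
            this.y = y;
        }
    }

    private static final Unsafe U = loadUnsafe();
    // The "schema": field offsets looked up once per class.
    private static final long X_OFFSET = offset("x");
    private static final long Y_OFFSET = offset("y");

    private static Unsafe loadUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    private static long offset(String fieldName) {
        try {
            return U.objectFieldOffset(Point.class.getDeclaredField(fieldName));
        } catch (NoSuchFieldException e) {
            throw new IllegalStateException(e);
        }
    }

    // Serialization: read the raw field values and write them in schema order.
    static byte[] write(Point p) {
        ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES + Long.BYTES);
        buf.putInt(U.getInt(p, X_OFFSET));
        buf.putLong(U.getLong(p, Y_OFFSET));
        return buf.array();
    }

    // Deserialization: allocate without a constructor, then fill the fields.
    static Point read(byte[] bytes) {
        ByteBuffer buf = ByteBuffer.wrap(bytes);
        try {
            Point p = (Point) U.allocateInstance(Point.class);
            U.putInt(p, X_OFFSET, buf.getInt());
            U.putLong(p, Y_OFFSET, buf.getLong());
            return p;
        } catch (InstantiationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Point copy = read(write(new Point(7, 42L)));
        System.out.println(copy.x + " " + copy.y); // 7 42
    }
}
```

A general-purpose version would build the offset table by iterating over the fields with reflection and would prefix the output with a class identifier.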

Big Arrays

As you know, Integer.MAX_VALUE is the maximum size of a Java array. Using direct memory allocation, we can create arrays whose size is limited only by the available native memory.

class SuperArray {
    private final static int BYTE = 1;

    private long size;
    private long address;

    public SuperArray(long size) {
        this.size = size;
        address = getUnsafe().allocateMemory(size * BYTE);
    }

    public void set(long i, byte value) {
        getUnsafe().putByte(address + i * BYTE, value);
    }

    public int get(long idx) {
        return getUnsafe().getByte(address + idx * BYTE);
    }

    public long size() {
        return size;
    }
}


And sample usage:

long SUPER_SIZE = (long)Integer.MAX_VALUE * 2;
SuperArray array = new SuperArray(SUPER_SIZE);
System.out.println("Array size:" + array.size()); // 4294967294
int sum = 0;
for (int i = 0; i < 100; i++) {
    array.set((long)Integer.MAX_VALUE + i, (byte)3);
    sum += array.get((long)Integer.MAX_VALUE + i);
}
System.out.println("Sum of 100 elements:" + sum);  // 300

Memory allocated this way is not located in the heap and is not managed by the GC, so you must release it yourself with Unsafe.freeMemory(). No boundary checks are performed either, so any illegal access may crash the JVM.
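One way to deal with both caveats is to wrap the allocation in an AutoCloseable class that checks indexes and frees the memory in close(). The following is a minimal sketch under that assumption; the class name OffHeapBytes is hypothetical and it is not a replacement for SuperArray, just a variation on it.

```java
import java.lang.reflect.Field;
import sun.misc.Unsafe;

// Hypothetical wrapper: adds bounds checks and deterministic cleanup
// around a block of off-heap memory.
class OffHeapBytes implements AutoCloseable {
    private static final Unsafe U = loadUnsafe();

    private final long size;
    private long address; // 0 once the memory has been freed

    private static Unsafe loadUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("sun.misc.Unsafe is not available", e);
        }
    }

    OffHeapBytes(long size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive");
        }
        this.size = size;
        this.address = U.allocateMemory(size);
        U.setMemory(address, size, (byte) 0); // allocateMemory does not zero the block
    }

    void set(long index, byte value) {
        check(index);
        U.putByte(address + index, value);
    }

    byte get(long index) {
        check(index);
        return U.getByte(address + index);
    }

    // Rejects accesses that would touch freed or foreign memory.
    private void check(long index) {
        if (address == 0) {
            throw new IllegalStateException("memory already freed");
        }
        if (index < 0 || index >= size) {
            throw new IndexOutOfBoundsException("index " + index);
        }
    }

    @Override
    public void close() {
        if (address != 0) {
            U.freeMemory(address);
            address = 0; // makes close() safe to call twice
        }
    }

    public static void main(String[] args) {
        try (OffHeapBytes bytes = new OffHeapBytes(16)) {
            bytes.set(3, (byte) 5);
            System.out.println(bytes.get(3)); // 5
        }
    }
}
```

With try-with-resources the memory is released even if the body throws, at the cost of one extra comparison per access.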


Concurrency

A few words about concurrency with Unsafe. The compareAndSwap methods are atomic and can be used to implement high-performance lock-free data structures.

For example, consider the problem of incrementing a value in a shared object from many threads.

First we define a simple interface, Counter:

interface Counter {
    void increment();
    long getCounter();
}

Then we define a worker, CounterClient, that uses Counter:

class CounterClient implements Runnable {
    private Counter c;
    private int num;

    public CounterClient(Counter c, int num) {
        this.c = c;
        this.num = num;
    }

    @Override
    public void run() {
        for (int i = 0; i < num; i++) {
            c.increment();
        }
    }
}

And this is testing code:

int NUM_OF_THREADS = 1000;
int NUM_OF_INCREMENTS = 100000;
ExecutorService service = Executors.newFixedThreadPool(NUM_OF_THREADS);
Counter counter = ... // creating instance of specific counter
long before = System.currentTimeMillis();
for (int i = 0; i < NUM_OF_THREADS; i++) {
    service.submit(new CounterClient(counter, NUM_OF_INCREMENTS));
}
service.shutdown();
service.awaitTermination(1, TimeUnit.MINUTES);
long after = System.currentTimeMillis();
System.out.println("Counter result: " + counter.getCounter());
System.out.println("Time passed in ms: " + (after - before));

The first implementation is an unsynchronized counter:

class StupidCounter implements Counter {
    private long counter = 0;

    @Override
    public void increment() {
        counter++;
    }

    @Override
    public long getCounter() {
        return counter;
    }
}

Output:

Counter result: 99542945
Time passed in ms: 679

It is fast, but there is no thread coordination at all, so the result is inaccurate.

For the second attempt, we add the simplest Java synchronization:

class SyncCounter implements Counter {
    private long counter = 0;

    @Override
    public synchronized void increment() {
        counter++;
    }

    @Override
    public long getCounter() {
        return counter;
    }
}

Output:

Counter result: 100000000
Time passed in ms: 10136

Radical synchronization always works, but the timings are awful.

Next, let's try ReentrantReadWriteLock:

class LockCounter implements Counter {
    private long counter = 0;
    private WriteLock lock = new ReentrantReadWriteLock().writeLock();

    @Override
    public void increment() {
        lock.lock();
        try {
            counter++;
        } finally {
            lock.unlock(); // always release the lock, even if the body throws
        }
    }

    @Override
    public long getCounter() {
        return counter;
    }
}

Output:

Counter result: 100000000
Time passed in ms: 8065

Still correct, and the timings are better.

What about atomics?

class AtomicCounter implements Counter {
    AtomicLong counter = new AtomicLong(0);

    @Override
    public void increment() {
        counter.incrementAndGet();
    }

    @Override
    public long getCounter() {
        return counter.get();
    }
}

Output:

Counter result: 100000000
Time passed in ms: 6552

AtomicCounter is even better.

Finally, let's try the Unsafe primitive compareAndSwapLong to see whether using it directly is really an advantage.

class CASCounter implements Counter {
    private volatile long counter = 0;
    private Unsafe unsafe;
    private long offset;

    public CASCounter() throws Exception {
        unsafe = getUnsafe();
        offset = unsafe.objectFieldOffset(CASCounter.class.getDeclaredField("counter"));
    }

    @Override
    public void increment() {
        long before = counter;
        while (!unsafe.compareAndSwapLong(this, offset, before, before + 1)) {
            before = counter;
        }
    }

    @Override
    public long getCounter() {
        return counter;
    }
}

Output:

Counter result: 100000000
Time passed in ms: 6454

The result is about the same as with atomics, because atomics are built on top of Unsafe.
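Since atomics rely on the same compare-and-swap primitive, the retry loop from CASCounter can also be written against the public AtomicLong API. The sketch below is only an illustration of that loop: the class name CasLoopDemo and the small thread counts are assumptions, and it does not reproduce the benchmark above.

```java
import java.util.concurrent.atomic.AtomicLong;

class CasLoopDemo {
    // Same retry loop as CASCounter.increment(), using compareAndSet:
    // read the current value, try to swap in value + 1, retry on failure.
    static void increment(AtomicLong counter) {
        long before;
        do {
            before = counter.get();
        } while (!counter.compareAndSet(before, before + 1));
    }

    // Runs the given number of threads, each incrementing a shared counter,
    // and returns the final value once all threads have finished.
    static long run(int threads, int incrementsPerThread) {
        AtomicLong counter = new AtomicLong(0);
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; t++) {
            workers[t] = new Thread(() -> {
                for (int i = 0; i < incrementsPerThread; i++) {
                    increment(counter);
                }
            });
            workers[t].start();
        }
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting", e);
            }
        }
        return counter.get();
    }

    public static void main(String[] args) {
        System.out.println(run(4, 10_000)); // 40000
    }
}
```

No increment is lost because a failed compareAndSet means another thread changed the value first, and the loop simply retries with the fresh value.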

This example is simple, but it shows some of the power of Unsafe.