package org.argeo.jjml.llm;

import java.nio.ByteBuffer;

/** Access to a serialized context state. */
public interface LlamaCppContextState {
	/**
	 * Stores the state of the given context, together with a context position.
	 *
	 * @param context         the context whose state is stored
	 * @param contextPosition the position to remember along with the state; it is
	 *                        what a later {@link #load(LlamaCppContext)} returns
	 */
	void save(LlamaCppContext context, int contextPosition);

	/**
	 * Applies the previously stored state to the given context.
	 *
	 * @param context the context receiving the stored state
	 * @return the context position that was stored along with the state
	 */
	int load(LlamaCppContext context);

	/** Serialized context state based on a {@link ByteBuffer}. */
	static class ByteBufferSavedState implements LlamaCppContextState {
		private ByteBuffer savedState;
		private int savedContextPosition;

		/**
		 * Allocates a direct buffer with the capacity reported by
		 * {@code context.getStateSize()}, passes it to {@code context.readState} and
		 * keeps it along with {@code contextPosition}. Any previously saved state is
		 * replaced.
		 *
		 * @throws ArithmeticException if the reported state size does not fit in an
		 *                             {@code int}, which is the maximum capacity of a
		 *                             {@link ByteBuffer}
		 */
		@Override
		public void save(LlamaCppContext context, int contextPosition) {
			// Fail loudly rather than silently truncating an oversized state.
			int stateSize = Math.toIntExact(context.getStateSize());
			ByteBuffer buffer = ByteBuffer.allocateDirect(stateSize);
			context.readState(buffer);
			// Publish buffer and position together, only once the read has completed,
			// so that a failure above leaves the previously saved state untouched.
			savedState = buffer;
			savedContextPosition = contextPosition;
		}

		/**
		 * Passes the saved buffer content to {@code context.writeState} and returns
		 * the position recorded by the last
		 * {@link #save(LlamaCppContext, int)}. The saved buffer itself is not
		 * modified, so the same state can be loaded several times.
		 *
		 * @throws IllegalStateException if no state has been saved yet
		 */
		@Override
		public int load(LlamaCppContext context) {
			if (savedState == null)
				throw new IllegalStateException("No context state has been saved");
			// Work on a view sharing the same content, so that flipping it (and whatever
			// writeState does to its position) does not alter the stored buffer.
			// duplicate() does not carry over the byte order, hence it is set explicitly.
			ByteBuffer view = savedState.duplicate().order(savedState.order());
			view.flip();
			context.writeState(view);
			return savedContextPosition;
		}

	}
}
