TakeScheduler.java

package org.microspace.event;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.microspace.exception.IllegalOperationException;
import org.microspace.space.MicroSpace;
import org.microspace.table.AtomicBigInteger;
import org.microspace.table.column.Accessor;
import org.microspace.table.column.GetSetPair;
import org.microspace.thread.MicroSpaceThreadFactory;
import org.microspace.util.AccessorCache;
import org.microspace.util.MicroLogger;

/**
 * Executes messages taken by a {@link TakeInterest} on a fixed-size thread pool.
 * <p>
 * This thread is the <em>assigner</em>: it moves queued tasks onto the pool so that at most one
 * task per ThreadId is executing at any time and no more than {@code threadPoolSize} tasks are
 * assigned at once. Tasks of a ThreadId locked with {@code lockThread} are held back until it is
 * unlocked, and nothing is assigned while messaging is suspended. A message scheduled while a task
 * with the same message id is still pending is merged into the pending entry instead of being
 * queued a second time.
 * <p>
 * All queue state is guarded by {@code assignLock}.
 *
 * @author Gaspar Sinai - {@literal gaspar.sinai@microspace.org}
 * @version 2016-06-26
 * @param <T>
 *            is the type of the ThreadId.
 */
class TakeScheduler<T> extends Thread {

	private static final MicroLogger log = new MicroLogger(TakeScheduler.class);

	private static final int DEFAULT_THREAD_POOL_SIZE = 32;

	/** Interval at which {@link #shutdown()} reports the tasks it is still waiting for. */
	private static final long SHUTDOWN_POLL_MILLIS = 10000L;

	/** How long to wait for the task pool to terminate once it has been shut down. */
	private static final long EXECUTOR_TERMINATION_SECONDS = 10;

	/** Numbers scheduler instances; used to build thread names. */
	private static final AtomicInteger instance = new AtomicInteger();

	/** Counter passed to every new TakeTask (presumably its sequence number). */
	private static final AtomicBigInteger sequenceCounter = new AtomicBigInteger(BigInteger.ZERO);

	private final MicroSpace space;

	/** ThreadId of the task running on the calling pool thread, or null. */
	private final ThreadLocal<T> currentThreadId = new ThreadLocal<T>();

	// ---- State below is guarded by assignLock ----

	/** Tasks waiting to be assigned. */
	private final IncomingQueue<T> incomingQueue = new IncomingQueue<>();

	/** Pending (unassigned) tasks by message identity; used to merge re-scheduled messages. */
	private final MergeMap<T> mergeMap = new MergeMap<>();

	/** Tasks handed to the pool and not yet completed, keyed by ThreadId. */
	private final Map<T, TakeTask<T, ?>> assignedMap = new HashMap<T, TakeTask<T, ?>>(512, 0.6f);

	/** ThreadIds whose tasks must not be assigned until unlocked. */
	private final Set<T> lockedThreads = new HashSet<T>(512, 0.6f);

	/** While true, assignTasks() assigns nothing. */
	private boolean suspendProcessing = false;

	private final Object assignLock = new Object();
	private final ExecutorService executorService;
	private final int threadPoolSize;
	private final AccessorCache accessorCache;
	private final MicroSpaceThreadFactory threadFactory;

	/** Volatile because it is read outside assignLock (e.g. by shutdown()). */
	volatile boolean started = false;

	/** Counted down when the assigner thread has finished and cleaned up. */
	volatile CountDownLatch countDownLatch;

	public TakeScheduler(MicroSpace space, int threadPoolSize) {
		this(space, threadPoolSize, "takescheduler-" + instance.incrementAndGet());
	}

	private TakeScheduler(MicroSpace space, int threadPoolSize, String name) {
		super(name + "-assigner");
		this.space = space;
		this.threadPoolSize = threadPoolSize;

		this.threadFactory = new MicroSpaceThreadFactory(name + "-task");
		this.executorService = Executors.newFixedThreadPool(threadPoolSize, threadFactory);
		this.accessorCache = new AccessorCache(space.getAccessorGenerator());
	}

	public TakeScheduler(MicroSpace space) {
		this(space, DEFAULT_THREAD_POOL_SIZE);
	}

	/**
	 * Returns the ThreadId of the task executing on the calling thread.
	 *
	 * @return the ThreadId, or null when the caller is not running a task of this scheduler.
	 */
	public T getCurrentThreadId() {
		return currentThreadId.get();
	}

	/**
	 * Queues a message taken by a TakeInterest for execution.
	 * <p>
	 * If a task with the same message id is already pending, both must have the same ThreadId and
	 * the task must be mergeable; the new task then replaces the pending entry in the merge map
	 * instead of being queued again.
	 *
	 * @param <M>
	 *            the message type.
	 * @param ti
	 *            The take interest.
	 * @param message
	 *            The message
	 * @return true if the task was queued or merged; false if it was rejected because it has no
	 *         ThreadId, its ThreadId differs from the pending task's, or it is not mergeable.
	 */
	public <M> boolean schedule(TakeInterest<T, M> ti, M message) {
		TakeTask<T, M> task = new TakeTask<T, M>(ti, message, sequenceCounter.incrementAndGet());
		if (task.getThreadId() == null) {
			log.error("Message does not have ThreadId. It will be discarded $*", null, message);
			return false;
		}
		synchronized (assignLock) {
			TakeTask<T, ?> oldTask = mergeMap.get(task);
			if (oldTask != null) {
				if (!oldTask.getThreadId().equals(task.getThreadId())) {
					log.error("Different ThreadId $* != $* Same Message ID $*.", null, oldTask.getThreadId(),
							task.getThreadId(), task.getMessageId());
					return false;
				}
				if (!task.isMergeMessages()) {
					log.error("Same Message ID $* for Not Mergable Tasks.", null, task.getMessageId());
					return false;
				}
			}
			mergeMap.put(task);
			if (oldTask == null) {
				incomingQueue.add(task);
			} else if (incomingQueue.getHeadSize() == 0) {
				// NOTE(review): a pending task exists in mergeMap but the queue reports an empty
				// head; re-queue so the merged task is not lost. Exact meaning of getHeadSize()
				// is defined in IncomingQueue — confirm.
				log.error("Recovering from internal error. oldTask=$* newTask=$*", null,
						oldTask.getMessage(), task.getMessage());
				incomingQueue.add(task);
			}
			assignLock.notifyAll();
		}
		return true;
	}

	/**
	 * Starts the assigner thread.
	 *
	 * @throws IllegalOperationException
	 *             if this scheduler has already been started.
	 */
	@Override
	public void start() {
		if (started) {
			throw new IllegalOperationException("Service already started");
		}
		countDownLatch = new CountDownLatch(1);
		started = true;
		super.start();
	}

	/**
	 * Stops the scheduler without waiting for assigned tasks to be drained first: the task pool is
	 * shut down (if started), the assigner thread is interrupted, and the pool is given up to 10
	 * seconds to terminate. The caller's interrupt status is preserved.
	 *
	 * @deprecated Prefer {@link #shutdown()}, which waits for assigned tasks to finish before
	 *             stopping the pool.
	 */
	@Deprecated
	public void forceShutdown() {
		if (started) {
			log.info("Shutting down executor service.");
			executorService.shutdown();
		}
		synchronized (assignLock) {
			started = false;
		}
		interrupt();
		awaitExecutorTermination();
	}

	/**
	 * Stops assigning new tasks and waits until no task other than the caller's own is running.
	 * <p>
	 * Assignment stays suspended on return (call {@link #unsuspendMessaging()} to resume), even
	 * when the timeout expires; in that case the still-running tasks are logged. If waiting fails
	 * (interrupt or runtime error) messaging is un-suspended again. An interrupt is not rethrown,
	 * but the caller's interrupt status is restored.
	 *
	 * @param timeoutMillisecs
	 *            maximum time to wait for running tasks; a value {@code <= 0} checks only once.
	 */
	public void suspendMessaging(long timeoutMillisecs) {
		synchronized (assignLock) {
			suspendProcessing = true;
		}
		try {
			List<TakeTask<T, ?>> assignedTasks;
			long start = System.currentTimeMillis();
			synchronized (assignLock) {
				while (true) {
					// Re-snapshot after every wait so tasks finishing during the last wait count.
					assignedTasks = new ArrayList<TakeTask<T, ?>>(assignedMap.values());
					if (isIdleExceptCurrentThread(assignedTasks)) {
						return;
					}
					long now = System.currentTimeMillis();
					if (now < start) {
						// Wall clock stepped backwards; restart the measurement.
						start = now;
					}
					long remainingTime = timeoutMillisecs - (now - start);
					if (remainingTime <= 0) {
						break;
					}
					assignLock.wait(remainingTime);
				}
			}
			log.error("suspendProcessing giving up after $* milliseconds for $* tasks to finish.", null,
					timeoutMillisecs, assignedTasks.size());
			logTasks(assignedTasks);
		} catch (RuntimeException | Error e) {
			unsuspendMessaging();
			log.warn("suspendProcessing error.", e);
			throw e;
		} catch (InterruptedException e) {
			unsuspendMessaging();
			log.warn("suspendProcessing error.", e);
			Thread.currentThread().interrupt();
		}
	}

	/** Resumes task assignment after {@link #suspendMessaging(long)}. */
	public void unsuspendMessaging() {
		synchronized (assignLock) {
			suspendProcessing = false;
			assignLock.notifyAll();
		}
	}

	/**
	 * Gracefully stops the scheduler: stops the assigner thread and waits for it to finish, waits
	 * (indefinitely, logging every 10 seconds) for all assigned tasks to complete, then shuts the
	 * task pool down. Does nothing if the scheduler is not running. If the caller is interrupted,
	 * shutdown still completes and the interrupt status is restored on return.
	 */
	public void shutdown() {
		if (!started) {
			return;
		}
		synchronized (assignLock) {
			started = false;
		}
		interrupt();

		boolean interrupted = false;
		try {
			countDownLatch.await();
		} catch (InterruptedException e) {
			interrupted = true;
		}
		log.info("Shutting down executor service.");
		List<TakeTask<T, ?>> assignedTasks = getAssignedTasks();
		while (!assignedTasks.isEmpty()) {
			log.warn("Waiting for $* tasks to finish.", null, assignedTasks.size());
			logTasks(assignedTasks);
			try {
				awaitNoAssignedTasks(SHUTDOWN_POLL_MILLIS);
			} catch (InterruptedException e) {
				interrupted = true;
			}
			assignedTasks = getAssignedTasks();
		}
		executorService.shutdown();
		awaitExecutorTermination();
		if (interrupted) {
			Thread.currentThread().interrupt();
		}
	}

	/**
	 * Assigner loop: assigns queued tasks and waits on assignLock until notified, until stopped or
	 * interrupted. Pending state is cleared and {@code countDownLatch} counted down in all cases,
	 * including an unexpected exception, so {@link #shutdown()} cannot hang.
	 */
	@Override
	public void run() {
		started = true;
		try {
			while (started) {
				synchronized (assignLock) {
					assignTasks();
					if (isInterrupted()) {
						break;
					}
					assignLock.wait();
				}
			}
		} catch (InterruptedException e) {
			// shutdown()/forceShutdown() interrupt this thread to stop it; fall through to clean-up.
		} catch (RuntimeException | Error e) {
			log.error("Assign-thread terminated unexpectedly.", e);
			throw e;
		} finally {
			started = false;
			synchronized (assignLock) {
				incomingQueue.clear();
				mergeMap.clear();
				lockedThreads.clear();
			}
			countDownLatch.countDown();
		}
	}

	/**
	 * Thread safe remove.
	 * 
	 * @param task
	 *            The task to be removed after it completed.
	 */
	public void taskCompleted(TakeTask<T, ?> task) {
		synchronized (assignLock) {
			assignedMap.remove(task.getThreadId());
			assignLock.notifyAll();
		}
	}

	/**
	 * Drops not-yet-assigned tasks whose message has the same primary key as one of the given
	 * messages. A pending task is dropped only when its ThreadId equals the message's ThreadId;
	 * messages without a ThreadId, or with a different one, are logged and skipped. Tasks already
	 * assigned to a thread are not affected.
	 * <p>
	 * The accessor of the first message is used for every message, so the list is expected to
	 * hold a single message type.
	 *
	 * @param <M>
	 *            the message type.
	 * @param messages
	 *            The messages identifying the tasks to drop; may be empty.
	 */
	public <M> void removeUnassignedTasksByIdOf(List<M> messages) {
		if (messages.isEmpty()) {
			return;
		}
		@SuppressWarnings("unchecked") // getClass() of an M is a Class<? extends M>.
		Class<M> messageClass = (Class<M>) messages.get(0).getClass();
		Accessor<M> accessor = accessorCache.get(messageClass);
		GetSetPair<M> idGetter = accessor.getPrimaryKeyGetSetPair();
		GetSetPair<M> threadIdGetter = accessor.getThreadIdGetSetPair();
		synchronized (assignLock) {
			for (M m : messages) {
				Object messageThreadId = threadIdGetter.get(m);
				if (messageThreadId == null) {
					log.error("removeUnassignedTasksByIdOf with null ThreadId $*", null, m);
					continue;
				}
				Object id = idGetter.get(m);
				TakeTask<T, ?> old = mergeMap.get(m.getClass(), id);
				if (old == null) {
					// Already assigned or removed.
					continue;
				}
				if (!messageThreadId.equals(old.getThreadId())) {
					log.error("removeUnassignedTasksByIdOf with different ThreadId $*", null, m);
					continue;
				}
				// Only mergeMap is updated; assignTasks() skips queue entries missing from it.
				mergeMap.remove(m.getClass(), id);
			}
		}
	}

	/**
	 * Obtain the number of tasks that are assigned to a thread and being executed.
	 *
	 * @return Assigned task count.
	 */
	public int getAssignedSize() {
		synchronized (assignLock) {
			return assignedMap.size();
		}
	}

	/**
	 * Obtain the number of tasks that have not been assigned yet to a thread.
	 *
	 * @return Unassigned task count.
	 */
	public int getUnassignedSize() {
		synchronized (assignLock) {
			return mergeMap.size();
		}
	}

	/**
	 * Obtain the tasks that are assigned to a thread and being executed.
	 *
	 * @return A snapshot copy of all the tasks that are being executed.
	 */
	public List<TakeTask<T, ?>> getAssignedTasks() {
		synchronized (assignLock) {
			return new ArrayList<TakeTask<T, ?>>(assignedMap.values());
		}
	}

	/**
	 * Obtain the locked threads.
	 *
	 * @return A snapshot copy of the locked ThreadIds.
	 */
	public Set<T> getLockedThreads() {
		synchronized (assignLock) {
			return new HashSet<T>(lockedThreads);
		}
	}

	/**
	 * Lock execution of a thread: its queued tasks are not assigned until it is unlocked. A task
	 * of that thread already running is not affected.
	 *
	 * @param thread
	 *            The ThreadId to lock.
	 */
	public void lockThread(T thread) {
		synchronized (assignLock) {
			lockedThreads.add(thread);
		}
	}

	/**
	 * Unlock execution of a thread and wake the assigner so its queued tasks can run.
	 *
	 * @param thread
	 *            The ThreadId to unlock.
	 */
	public void unlockThread(T thread) {
		synchronized (assignLock) {
			lockedThreads.remove(thread);
			assignLock.notifyAll();
		}
	}

	/**
	 * Obtain the tasks that have not been assigned yet to a thread.
	 *
	 * @return A snapshot copy of all the queued tasks.
	 */
	public List<TakeTask<T, ?>> getUnassignedTasks() {
		List<TakeTask<T, ?>> ret = new ArrayList<TakeTask<T, ?>>();
		synchronized (assignLock) {
			for (TakeTask<T, ?> origin : incomingQueue) {
				ret.addAll(incomingQueue.getList(origin));
			}
		}
		return ret;
	}

	/**
	 * Obtain the tasks that are queued for the given thread.
	 *
	 * @param threadId
	 *            The ThreadId; null yields an empty list.
	 * @return A snapshot copy of the tasks pending for that thread.
	 */
	public List<TakeTask<T, ?>> getUnassignedTasksForThread(T threadId) {
		List<TakeTask<T, ?>> ret = new ArrayList<TakeTask<T, ?>>();
		if (threadId == null) {
			return ret;
		}
		synchronized (assignLock) {
			List<TakeTask<T, ?>> pending = incomingQueue.getTasksByThread(threadId);
			if (pending != null) {
				ret.addAll(pending);
			}
		}
		return ret;
	}

	/**
	 * Hands queued tasks to the pool, at most one per ThreadId and at most threadPoolSize in total,
	 * skipping ThreadIds that are busy or locked. Must be called while holding assignLock.
	 *
	 * @throws IllegalStateException
	 *             if the queue reports a ThreadId with no next task (internal inconsistency).
	 */
	private void assignTasks() {
		if (suspendProcessing) {
			return;
		}
		Iterator<TakeTask<T, ?>> it = incomingQueue.iterator();
		List<TakeTask<T, ?>> touchedOrigins = new ArrayList<TakeTask<T, ?>>();
		while (it.hasNext() && assignedMap.size() < threadPoolSize) {
			TakeTask<T, ?> origin = it.next();
			T threadId = origin.getThreadId();
			if (assignedMap.containsKey(threadId) || lockedThreads.contains(threadId)) {
				continue;
			}
			touchedOrigins.add(origin);
			TakeTask<T, ?> next = incomingQueue.removeNext(origin);
			if (next == null) {
				throw new IllegalStateException("incomingList next is null");
			}
			// Queue entries can be missing from mergeMap (e.g. dropped by
			// removeUnassignedTasksByIdOf); skip them until a live task is found.
			TakeTask<T, ?> task = mergeMap.remove(next);
			while (task == null && (next = incomingQueue.removeNext(origin)) != null) {
				task = mergeMap.remove(next);
			}
			if (task == null) {
				continue;
			}
			task.setAssignedTime(System.currentTimeMillis());
			assignedMap.put(task.getThreadId(), task);
			executorService.submit(new TakeTaskExecutor(task));
		}
		// Deferred until after iteration, since it may change the queue being iterated.
		for (TakeTask<T, ?> origin : touchedOrigins) {
			incomingQueue.recalculate(origin);
		}
	}

	/**
	 * True when no task is assigned, or the only assigned task belongs to the calling thread's
	 * own ThreadId (so a task may suspend messaging without waiting for itself).
	 */
	private boolean isIdleExceptCurrentThread(List<TakeTask<T, ?>> tasks) {
		if (tasks.isEmpty()) {
			return true;
		}
		if (tasks.size() > 1) {
			return false;
		}
		T self = getCurrentThreadId();
		return self != null && tasks.get(0).getThreadId().equals(self);
	}

	/** Logs each task's message at warn level, numbered from zero. */
	private void logTasks(List<TakeTask<T, ?>> tasks) {
		int index = 0;
		for (TakeTask<T, ?> task : tasks) {
			log.warn("task[$*] message=[$*]", null, index++, task.getMessage());
		}
	}

	/**
	 * Waits until no task is assigned or the timeout elapses, waking as tasks complete
	 * (taskCompleted() notifies assignLock).
	 */
	private void awaitNoAssignedTasks(long timeoutMillis) throws InterruptedException {
		long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
		synchronized (assignLock) {
			while (!assignedMap.isEmpty()) {
				long remainingNanos = deadline - System.nanoTime();
				if (remainingNanos <= 0) {
					return;
				}
				// wait(0) would block forever, so never round below one millisecond.
				assignLock.wait(Math.max(1L, TimeUnit.NANOSECONDS.toMillis(remainingNanos)));
			}
		}
	}

	/**
	 * Waits for the (already shut down) task pool to terminate and then shuts the thread factory
	 * down; logs if termination does not happen in time. Restores the interrupt status if
	 * interrupted.
	 */
	private void awaitExecutorTermination() {
		try {
			if (executorService.awaitTermination(EXECUTOR_TERMINATION_SECONDS, TimeUnit.SECONDS)) {
				threadFactory.shutdown();
			} else {
				log.error("Can not do clean thread shutdown.", null);
			}
		} catch (InterruptedException e) {
			log.error("Can not do clean thread shutdown.", e);
			Thread.currentThread().interrupt();
		} catch (Exception e) {
			log.error("Can not do clean thread shutdown.", e);
		}
	}

	/** Runs one task on a pool thread and always reports its completion. */
	class TakeTaskExecutor implements Runnable {
		final TakeTask<T, ?> takeTask;

		public TakeTaskExecutor(TakeTask<T, ?> takeTask) {
			this.takeTask = takeTask;
		}

		@Override
		public void run() {
			currentThreadId.set(takeTask.getThreadId());
			try {
				if (takeTask.isPerformTake()) {
					// tricky client may have read ahead the queue.
					if (space.takeByIdOf(takeTask.getMessage()) != null) {
						takeTask.handleTake();
						space.commit();
					}
				} else {
					takeTask.handleTake();
					space.commit();
				}
			} catch (Throwable e) {
				handleFailure(e);
			} finally {
				// Always release the ThreadId, otherwise its later tasks would never be assigned.
				currentThreadId.remove();
				taskCompleted(takeTask);
			}
		}

		/**
		 * Logs a task failure. For take tasks, rolls back and removes the message; then commits.
		 */
		private void handleFailure(Throwable e) {
			log.error("Message processor client error. Message = $0, Template = $1, Will be removed=$2", e,
					takeTask.getMessage(), takeTask.getQuery(), takeTask.isPerformTake());
			try {
				if (takeTask.isPerformTake()) {
					space.rollback();
					space.takeByIdOf(takeTask.getMessage());
				}
				// NOTE(review): for non-take tasks this commits any work done before the failure
				// without a rollback — confirm that is intended.
				space.commit();
			} catch (Throwable e2) {
				log.error("Message processor remove error1. Message = $0, Template = $1, Will be removed=$2", e2,
						takeTask.getMessage(), takeTask.getQuery(), takeTask.isPerformTake());
			}
		}
	}
}