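My approach builds on merging two sorted lists. If we merge each pair of adjacent lists in every round, we need ceil(log2(k)) rounds, and each round costs O(n) in total, where n is the total number of nodes across all k lists (merging two lists is linear in their combined length). Thus the running time is O(n log k).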
/**
 * Merges k sorted singly linked lists by repeatedly merging adjacent pairs.
 *
 * <p>Round r merges list {@code j} with list {@code j + step} (step = 2^r) for every j that is a
 * multiple of {@code 2 * step}, so ceil(log2(k)) rounds are needed. Each round touches every node
 * at most once, giving O(n log k) time for n total nodes.
 */
public class Solution {

    /**
     * Merges all sorted lists in {@code lists} into a single sorted list.
     *
     * <p>Nodes of the input lists are relinked in place; no nodes are allocated. The {@code lists}
     * container itself is not modified (the original implementation overwrote its entries).
     *
     * @param lists lists each sorted by ascending {@code val}; may be {@code null}, empty, or
     *     contain {@code null} entries
     * @return head of the merged list, or {@code null} if there are no nodes to merge
     */
    public ListNode mergeKLists(ArrayList<ListNode> lists) {
        if (lists == null || lists.isEmpty()) {
            return null;
        }
        // Work on a copy so the caller's list is not clobbered with intermediate merge results.
        ArrayList<ListNode> work = new ArrayList<ListNode>(lists);
        int k = work.size();
        // After the round with interval `step`, slot j (a multiple of 2*step) holds the merge of
        // slots j .. j + 2*step - 1. Integer doubling replaces the floating-point log2 ceiling
        // and the per-iteration Math.pow calls.
        for (int step = 1; step < k; step *= 2) {
            for (int j = 0; j < k - step; j += 2 * step) {
                work.set(j, mergeTwoLists(work.get(j), work.get(j + step)));
            }
        }
        return work.get(0);
    }

    /**
     * Merges two sorted lists into one by splicing the nodes of one into the other.
     *
     * <p>Nodes are relinked in place; no nodes are allocated.
     *
     * @param l1 list sorted by ascending {@code val}; may be {@code null}
     * @param l2 list sorted by ascending {@code val}; may be {@code null}
     * @return head of the merged list ({@code l2} if {@code l1} is null, {@code l1} if {@code l2}
     *     is null)
     */
    public ListNode mergeTwoLists(ListNode l1, ListNode l2) {
        if (l1 == null) {
            return l2;
        }
        if (l2 == null) {
            return l1;
        }
        // `base` is the list whose first node is smallest (l1 on ties); nodes of `other` are
        // spliced into it. Choosing by value and swapping references avoids the original
        // equals()-based check, which would drop l2 if ListNode.equals compared values.
        ListNode base = l1;
        ListNode other = l2;
        if (l1.val > l2.val) {
            base = l2;
            other = l1;
        }
        ListNode head = base;
        // `base` is always the last node already placed in the merged output.
        while (base.next != null && other != null) {
            if (base.next.val > other.val) {
                // Splice the head of `other` between `base` and `base.next`.
                ListNode rest = base.next;
                base.next = other;
                other = other.next;
                base = base.next;
                base.next = rest;
            } else {
                base = base.next;
            }
        }
        // `base` is the tail of its list; append whatever remains of `other`.
        if (other != null) {
            base.next = other;
        }
        return head;
    }
}