Your code is neat, the variables are well named, and the input validation is there. It's all good. Your assessment of the time complexity as \$O(\log_2{n})\$ is also right.
But is that as good as it can be? Your time-complexity assessment hints at a better solution: the base-2 log of a number also tells you how many bits it uses. Remember, in base 2 the number of bits needed goes up by one each time the value doubles.
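To spell that relationship out (a sketch, with \$k\$ introduced here as the bit count of a positive integer \$n\$):

$$2^{k-1} \le n < 2^{k} \iff k = \lfloor \log_2 n \rfloor + 1$$

For example, \$n = 5\$ is \$101\$ in binary: \$2^2 \le 5 < 2^3\$, so \$k = 3\$.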
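As a consequence, your function could be reduced to \$O(1)\$ with:

    import math

    def count_bits(num):
        assert num >= 0
        if num == 0:
            return 0  # math.log(0) would raise ValueError
        # floor(log2(num)) + 1 is the number of bits in a positive integer
        return 1 + int(math.log(num, 2))

One caveat: `math.log` works in floating point, so for very large `num` the result can be off by one.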
Note that Python 3.1 introduced the `int.bit_length()` method, so you could do:

    def count_bits(num):
        assert num >= 0
        return num.bit_length()
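If you want to convince yourself that the built-in matches a bit-by-bit count, a quick cross-check might look like the sketch below. Note that `count_bits_loop` is a hypothetical stand-in written for this comparison, not your original function, and the sample values are arbitrary.

```python
def count_bits_loop(num):
    """Reference count: shift right until no bits are left."""
    assert num >= 0
    bits = 0
    while num:
        num >>= 1
        bits += 1
    return bits


# The built-in agrees with the loop for small and very large values alike.
for value in [0, 1, 2, 3, 4, 255, 256, 2**64 - 1, 2**64, 2**200]:
    assert value.bit_length() == count_bits_loop(value)

# A float keeps only 53 significant bits, so these two integers
# (64 and 65 bits long) turn into the same float before any logarithm
# is taken. No float-based formula can tell them apart.
assert float(2**64 - 1) == float(2**64)
```

The last assertion is why the integer-only `bit_length()` is the safer choice once the numbers get large.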