Leetcode 笔记(1): Sqrt(x)

http://oj.leetcode.com/problems/sqrtx/

这一题看上去很简单,在不用std::sqrt的前提下,最容易想到的就是从1开始brute force,但这是O(n)的复杂度,明显行不通。 要找O(lg n) 的算法,我的思路是从1 开始直接保持double, 当平方超过x时就以double之前的值作为base,然后添加一个增量,增量持续double直到超出。于是有了以下的code:

class Solution {
public:

    inline int sq(int x) { return x * x; }
    int sqrt(int x) {
        int base = 0;
        while (sq(base) < x)
        {
            if (sq(base + 1) > x)
                break;

            int i = 1;
            while (sq(base + i) < x)
                i *= 2;
            if (sq(base + i) == x)
                return base + i;
            base += i/2;
        }
        return base;
    }
};

这段code在输入2147395599时出现了Time Limit Exceeded。在没有仔细分析的情况下,我以为是算法复杂度太高,放弃继续修改,转而寻找参考。 实际上,这里是发生了溢出,int sq(int x) 返回的是x * x, 在x足够大的情况下,会发生 sq(x)溢出继而被判断为负值,进入死循环。

在去掉sq(x)函数而改用 判断x/a  vs a之后,代码可以通过

class Solution {
public:

    int sqrt(int x) {
        if (x <= 1)
            return x;
        int base = 1;
        while (base < x / base)
        {
            if (base + 1 > x / (base + 1))
                break;

            int i = 1;
            while (base + i < x / (base + i))
                i *= 2;
            if (base + i == x / (base + i))
                return base + i;
            base += i/2;
        }
        return base;
    }
};

OJ的discuss页面有人提供了可行的方案: 包括数学的方法和更程序员式的二分法。

我的二分法实现出现了诸多问题,一个中间版本如下:

class Solution {
public:

    inline int sq(int x) { return x * x; }
    int sqrt(int x) {
        int l = 0; 
        int r = x;
        while (l < r)
        {
            int m = (l + r) / 2;

            if (m * m == x)
                return m;
            if (m * m < x)
            {
                if ((m+1) * (m+1) == x)
                    return m + 1;
                else if ((m+1) * (m+1) > x)       
                    return m;
            }   
            if (m * m > x)
                r = m - 1;
            else
                l = m + 1;
        }
        return l;
    }
};

这里有逻辑上和实现上的诸多问题,除了上面提到的sq(x),还有一个就是

 int m = (l + r) / 2;

对于两个int,求平均值防止溢出的安全计算应该是

 int m = l + (l - r) / 2;

在参考着原作者的代码下,终于将这一实现的思路理清;

class Solution {
public:
    int sqrt(int x) {
        if (x <= 1)
            return x;

        int l = 1; 
        int r = x;
        while (r - l > 1)
        {
            int m = l + (r - l) / 2;
            if (x / m == m)
                return m;
            if (x / m < m)
                r = m;
            else
                l = m;
        }
        return l;
    }
};

首先,对于输入为0和1的情况进行处理,从而保证 r  – l > 1  且结果在(l, r)这一区间

其次,进行二分查找,在r – l > 1的情况下,m != l && m != r,从而只需要判断m是否是平方根,如果不是,则平方根落在 (l, m)之间,反之则在(m, r)

二分查找结束,说明 r – l <= 1, 由于循环过程中m != l && m != r,所以结束时 r = l +1, 说明平方根严格落在(l, r)区间,则返回l

 

Leave a Reply

Your email address will not be published. Required fields are marked *