Complete-Preparation

🎉 One-stop destination for all your technical interview preparation 🎉

View the Project on GitHub

Chocolate Pickup 🌟🌟🌟

This problem is similar to LeetCode 1463: Cherry Pickup II.

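Conventions

- `grid` has `r` rows and `c` columns; `grid[i][j]` is the number of chocolates in cell (i, j).
- `i` is the current row; `aj` and `bj` are the columns of the first and second person.
- The first person starts at (0, 0) and the second at (0, c - 1).
- From row `i`, each person moves to row `i + 1` and shifts their column by -1, 0 or +1 (9 move combinations in total), staying inside the grid.
- If both people stand on the same cell, its chocolates are counted only once.
- The answer is the maximum total collected once both reach the last row.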

Recursive Solution

Code

// Recursive solution (exponential time; shown for illustration).
//
// Returns the maximum number of chocolates the two people can collect from
// row `i` downwards, given that the first person is in column `aj` and the
// second in column `bj` of row `i`. A cell both stand on is counted once.
// Returns INT_MIN when either column lies outside [0, c).
//
// From each row both people move to row i + 1 and shift their column by
// -1, 0 or +1 (9 combinations), so this takes O(9^r) time and O(r) stack.
int f(const vector<vector<int>>& grid, int r, int c, int i, int aj, int bj)
{
    if (aj < 0 || bj < 0 || aj >= c || bj >= c) return INT_MIN;

    // Chocolates picked up on the current row (shared cell counted once).
    const int here = (aj == bj) ? grid[i][aj] : grid[i][aj] + grid[i][bj];
    if (i == r - 1) return here;

    int best = INT_MIN;
    for (int dirAj = -1; dirAj <= 1; dirAj++) {
        for (int dirBj = -1; dirBj <= 1; dirBj++) {
            const int na = aj + dirAj;
            const int nb = bj + dirBj;
            // Skip moves that leave the grid instead of adding the INT_MIN
            // sentinel to `here`, which is signed overflow (UB) when cell
            // values can be negative.
            if (na < 0 || nb < 0 || na >= c || nb >= c) continue;
            best = max(best, here + f(grid, r, c, i + 1, na, nb));
        }
    }
    // The (0, 0) move is always in bounds from a valid state, so `best`
    // always holds a real total here.
    return best;
}

// Maximum chocolates collected when the first person starts at (0, 0) and
// the second at (0, c - 1). Returns 0 for an empty grid.
int maximumChocolates(int r, int c, vector<vector<int>>& grid)
{
    if (r <= 0 || c <= 0) return 0;
    return f(grid, r, c, 0, 0, c - 1);
}

Memoization

Code

// Memoized version of the recursive solution.
// O(r * c^2 * 9) time, O(r * c^2) memory plus O(r) recursion depth.
//
// Returns the maximum chocolates collectable from row `i` downwards when the
// first person is in column `aj` and the second in column `bj` (a shared cell
// is counted once); INT_MIN if either column lies outside [0, c).
// `dp[i][aj][bj]` caches results for rows above the last one; -1 means
// "not computed yet".
// NOTE(review): if cell values can be negative, a state whose true answer is
// -1 is recomputed on every visit (still correct, just not cached).
int f(const vector<vector<int>>& grid, int r, int c, int i, int aj, int bj, vector<vector<vector<int>>>& dp)
{
    if (aj < 0 || bj < 0 || aj >= c || bj >= c)
        return INT_MIN;

    // Chocolates picked up on the current row (shared cell counted once).
    const int here = (aj == bj) ? grid[i][aj] : grid[i][aj] + grid[i][bj];
    if (i == r - 1)
        return here;

    if (dp[i][aj][bj] != -1)
        return dp[i][aj][bj];

    int best = INT_MIN;
    for (int dirAj = -1; dirAj <= 1; dirAj++) {
        for (int dirBj = -1; dirBj <= 1; dirBj++) {
            const int na = aj + dirAj;
            const int nb = bj + dirBj;
            // Skip moves that leave the grid instead of adding the INT_MIN
            // sentinel to `here`, which is signed overflow (UB) when cell
            // values can be negative.
            if (na < 0 || nb < 0 || na >= c || nb >= c)
                continue;
            best = max(best, here + f(grid, r, c, i + 1, na, nb, dp));
        }
    }
    // The (0, 0) move is always in bounds, so `best` is a real total.
    return dp[i][aj][bj] = best;
}

// Maximum chocolates collected when the first person starts at (0, 0) and
// the second at (0, c - 1). Returns 0 for an empty grid.
int maximumChocolates(int r, int c, vector<vector<int>>& grid)
{
    if (r <= 0 || c <= 0)
        return 0;
    vector<vector<vector<int>>> dp(r, vector<vector<int>>(c, vector<int>(c, -1)));
    return f(grid, r, c, 0, 0, c - 1, dp);
}

Tabulation

Code

// Bottom-up tabulation.
// dp[i][aj][bj] = maximum chocolates collectable from row i downwards with the
// first person in column aj and the second in column bj (a shared cell is
// counted once). O(r * c^2 * 9) time, O(r * c^2) memory.
// Returns dp[0][0][c - 1], or 0 for an empty grid.
int maximumChocolates(int r, int c, vector<vector<int>>& grid)
{
    if (r <= 0 || c <= 0) return 0;  // dp[r - 1] would be out of range

    vector<vector<vector<int>>> dp(r, vector<vector<int>>(c, vector<int>(c, 0)));

    // Base case: on the last row only that row's chocolates count.
    for (int aj = 0; aj < c; aj++) {
        for (int bj = 0; bj < c; bj++) {
            dp[r - 1][aj][bj] = (aj == bj) ? grid[r - 1][aj]
                                           : grid[r - 1][aj] + grid[r - 1][bj];
        }
    }

    for (int i = r - 2; i >= 0; i--) {
        for (int aj = 0; aj < c; aj++) {
            for (int bj = 0; bj < c; bj++) {
                // Best value reachable in row i + 1 over the in-bounds moves.
                int bestNext = INT_MIN;
                for (int dirAj = -1; dirAj <= 1; dirAj++) {
                    for (int dirBj = -1; dirBj <= 1; dirBj++) {
                        const int na = aj + dirAj;
                        const int nb = bj + dirBj;
                        // Skip out-of-bound moves instead of adding INT_MIN,
                        // which is signed overflow (UB) for negative values.
                        if (na < 0 || nb < 0 || na >= c || nb >= c) continue;
                        bestNext = max(bestNext, dp[i + 1][na][nb]);
                    }
                }
                // The (0, 0) move is always valid, so bestNext is a real value.
                const int here = (aj == bj) ? grid[i][aj] : grid[i][aj] + grid[i][bj];
                dp[i][aj][bj] = here + bestNext;
            }
        }
    }
    return dp[0][0][c - 1];
}

Space Optimization

Code

// Space-optimized tabulation: keeps only two c x c layers instead of r.
// front[aj][bj] = best total from row i + 1 downwards; curr is the row being
// filled. O(r * c^2 * 9) time, O(c^2) memory. Returns 0 for an empty grid.
int maximumChocolates(int r, int c, vector<vector<int>>& grid)
{
    if (r <= 0 || c <= 0) return 0;  // grid[r - 1] would be out of range

    vector<vector<int>> front(c, vector<int>(c, 0));
    vector<vector<int>> curr(c, vector<int>(c, 0));

    // Base case: on the last row only that row's chocolates count.
    for (int aj = 0; aj < c; aj++) {
        for (int bj = 0; bj < c; bj++) {
            front[aj][bj] = (aj == bj) ? grid[r - 1][aj]
                                       : grid[r - 1][aj] + grid[r - 1][bj];
        }
    }

    for (int i = r - 2; i >= 0; i--) {
        for (int aj = 0; aj < c; aj++) {
            for (int bj = 0; bj < c; bj++) {
                // Best value reachable in row i + 1 over the in-bounds moves.
                int bestNext = INT_MIN;
                for (int dirAj = -1; dirAj <= 1; dirAj++) {
                    for (int dirBj = -1; dirBj <= 1; dirBj++) {
                        const int na = aj + dirAj;
                        const int nb = bj + dirBj;
                        // Skip out-of-bound moves instead of adding INT_MIN,
                        // which is signed overflow (UB) for negative values.
                        if (na < 0 || nb < 0 || na >= c || nb >= c) continue;
                        bestNext = max(bestNext, front[na][nb]);
                    }
                }
                // The (0, 0) move is always valid, so bestNext is a real value.
                const int here = (aj == bj) ? grid[i][aj] : grid[i][aj] + grid[i][bj];
                curr[aj][bj] = here + bestNext;
            }
        }
        // O(1) buffer exchange instead of deep-copying c^2 ints; the stale
        // contents left in `curr` are fully overwritten on the next row.
        front.swap(curr);
    }
    return front[0][c - 1];
}