Do you feel DP is hard? Let's destroy it! Python ONLY

Do you find DP hard to understand? Do you struggle to understand what each dimension of a dp array means?
Let's forget about it and just concentrate on the fun bit.

Recently I've been using @cache (functools.cache, added in Python 3.9) in my Python code, and I've found it a great way to write understandable DP code.
People who are just starting to learn DP may run into issues like these:

  • I understand the problem and I know the state transition equation, but how do I write the for loops that fill the dp array?
  • I know where the edge cases are, but I struggle to get the +1 / -1 array boundaries right.

Well, I had the same issues until I started using @cache in my code. You just need to follow these 4 steps:

  1. Remember what problem you are trying to solve
  2. Describe the subproblem with a state transition equation (don't forget the edge cases)
  3. Wrap the function with @cache
  4. That's it!
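
The recipe can be tried on a toy problem first. Below is a minimal sketch, assuming the classic "climbing stairs" setup (count the ways to climb n stairs taking 1 or 2 steps at a time), which is chosen here for illustration and is not one of the problems in this post; the function name `ways` is hypothetical.

```python
from functools import cache


# Step 1: the problem is "how many ways can I climb n stairs, 1 or 2 at a time?"
@cache  # Step 3: memoize every subproblem ways(k)
def ways(n: int) -> int:
    # Step 2a: edge cases (0 or 1 stairs can only be climbed one way)
    if n <= 1:
        return 1
    # Step 2b: the state transition equation
    return ways(n - 1) + ways(n - 2)


print(ways(10))  # 89
```

Without @cache the same recursion recomputes ways(k) exponentially many times; with it, each ways(k) is computed once, so even ways(50) returns instantly.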

Technically, this is recursion + memoization (top-down DP). I hope this trick gives you a better start with DP, even before you know how to set up a dp array.
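
To make "recursion + memoization" concrete, here is a hand-rolled sketch of roughly what @cache does: a dictionary keyed by the function's arguments, filled the first time each subproblem is solved and read back afterwards. `memoize` and `paths` are hypothetical names for illustration only; the real functools.cache also supports keyword arguments and exposes cache_info().

```python
from typing import Any, Callable, Dict, Tuple


def memoize(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical, bare-bones stand-in for functools.cache
    (positional, hashable arguments only)."""
    memo: Dict[Tuple[Any, ...], Any] = {}

    def wrapper(*args: Any) -> Any:
        if args not in memo:        # first visit: solve the subproblem
            memo[args] = fn(*args)  # the recursion happens here
        return memo[args]           # later visits: just look it up

    return wrapper


@memoize
def paths(r: int, c: int) -> int:
    """Ways to walk from (0, 0) to (r, c) moving only down or right."""
    if r == 0 or c == 0:  # edge case: only one straight line
        return 1
    return paths(r - 1, c) + paths(r, c - 1)


print(paths(2, 6))  # 28 paths across a 3 x 7 grid
```

Because the recursive calls go through the decorated name `paths`, every subproblem hits the dictionary, which is the same mechanism @cache relies on.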

OK, let's look at an example.
1335. Minimum Difficulty of a Job Schedule

  1. Key problem:
    Split n jobs (in order) into d days, with at least one job per day, and minimize the sum of each day's hardest job
  2. State transition equation:
    After the first d-1 days you have finished the first k jobs; the remaining jobs[k:n] go to the last day
    dp(n, d) = min over k of ( dp(k, d-1) + max(jobs[k:n]) )
    where k ranges from d-1 (at least one job on each of the first d-1 days) to n-1 (at least one job left for the last day)
    Edge case: dp(n, 1) = max(jobs[0:n]), since a single day must take every job
  3. Each subproblem always returns the same result for the same (n, d), so we can memoize it with @cache
  4. That's it! Here is the full code:
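from functools import cache
from typing import List


class Solution:
    def minDifficulty(self, jobs: List[int], d: int) -> int:
        # Every day needs at least one job, so fewer jobs than days is impossible
        if len(jobs) < d:
            return -1

        @cache
        def maxR(i, j):
            # Hardest job among jobs[i:j]
            return max(jobs[i:j])

        @cache
        def dp(n, d):
            # Minimum difficulty of finishing the first n jobs in d days
            if d == 1:
                # One day left: do all n jobs today
                return maxR(0, n)
            ans = float("inf")
            # The first k jobs go to the first d-1 days, jobs[k:n] to the last day
            for k in range(d - 1, n):
                ans = min(ans, dp(k, d - 1) + maxR(k, n))
            return ans

        n = len(jobs)
        return dp(n, d)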
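
For comparison, here is a sketch of the same recurrence filled in bottom-up, i.e. the loop-based version that the @cache approach lets you skip. The function name `min_difficulty_bottom_up` and the table layout `dp[day][i]` are my own illustrative choices. Note how the loop order and the `day - 1` / `i - 1` boundaries all have to be worked out by hand.

```python
from typing import List


def min_difficulty_bottom_up(jobs: List[int], d: int) -> int:
    n = len(jobs)
    if n < d:
        return -1
    INF = float("inf")
    # dp[day][i]: min difficulty of doing the first i jobs in `day` days
    dp = [[INF] * (n + 1) for _ in range(d + 1)]

    # One day: the first i jobs all happen today
    running = jobs[0]
    for i in range(1, n + 1):
        running = max(running, jobs[i - 1])
        dp[1][i] = running

    for day in range(2, d + 1):
        for i in range(day, n + 1):  # need at least `day` jobs
            # The last day does jobs[k:i]; walk k downwards so the
            # last day's maximum can be updated incrementally
            last_max = jobs[i - 1]
            for k in range(i - 1, day - 2, -1):  # k = i-1 .. day-1
                last_max = max(last_max, jobs[k])
                dp[day][i] = min(dp[day][i], dp[day - 1][k] + last_max)

    return dp[d][n]


print(min_difficulty_bottom_up([6, 5, 4, 3, 2, 1], 2))  # 7
```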

Another example:
221. Maximal Square

  1. Key problem:
    Find the largest square containing only 1s and return its area
  2. State transation equation:
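    Let dp(i, j) be the side length of the largest square of 1s whose bottom-right corner is (i, j). Scan from top-left to bottom-right; for each cell that is 1, check its top, left and top-left neighbours and take the minimum.
    At position (i, j):
    dp(i, j) = min(dp(i-1, j), dp(i, j-1), dp(i-1, j-1)) + 1   if matrix[i][j] is 1
    dp(i, j) = 0                                               if matrix[i][j] is 0
    Edge case: on the first row or column, dp(i, j) is simply matrix[i][j]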
  3. Each subproblem is identified by (i, j) alone: the matrix never changes, and n and m are constants, so @cache on dp(i, j) is enough
  4. That's it! Here is the full code:
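from functools import cache
from typing import List


class Solution:
    def maximalSquare(self, matrix: List[List[str]]) -> int:
        @cache
        def dp(i, j):
            # Side of the largest all-'1' square whose bottom-right corner is (i, j)
            if i == 0 or j == 0:
                # First row/column: at most a 1x1 square
                return int(matrix[i][j])
            if matrix[i][j] == '0':
                return 0
            return min(dp(i - 1, j - 1), dp(i - 1, j), dp(i, j - 1)) + 1

        n = len(matrix)
        m = len(matrix[0])
        ans = 0

        # Scanning row by row means the top, left and top-left neighbours
        # of (i, j) are already cached, so the recursion stays shallow
        for i in range(n):
            for j in range(m):
                ans = max(ans, dp(i, j))

        # ans is a side length; the problem asks for the area
        return ans * ans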
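
For comparison, here is a sketch of the bottom-up table that this @cache version replaces. It pads the table with an extra zero row and column so the first row and column need no special case, which is exactly the kind of +1 / -1 boundary bookkeeping mentioned at the top of the post. The name `maximal_square_bottom_up` is hypothetical.

```python
from typing import List


def maximal_square_bottom_up(matrix: List[List[str]]) -> int:
    n, m = len(matrix), len(matrix[0])
    # dp[i + 1][j + 1] is the side of the largest all-'1' square ending at
    # matrix[i][j]; row 0 and column 0 are zero padding
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    best = 0
    for i in range(n):
        for j in range(m):
            if matrix[i][j] == "1":
                # top-left, top and left neighbours, shifted by the padding
                dp[i + 1][j + 1] = min(dp[i][j], dp[i][j + 1], dp[i + 1][j]) + 1
                best = max(best, dp[i + 1][j + 1])
    return best * best


print(maximal_square_bottom_up([["0", "1"], ["1", "0"]]))  # 1
```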