# Calculate Maximum Information Gain (Python 3)
import collections
import math
from typing import List


class Solution:
    """Find the best single-threshold split of a numeric feature by information gain."""

    def calculateMaxInfoGain(self, petal_length: List[float], species: List[str]) -> float:
        """Return the maximum information gain over all threshold splits.

        The samples are sorted by ``petal_length`` and every boundary between
        two *distinct* neighbouring lengths is tried as a split point.  The
        gain of a split is the entropy of the whole group minus the
        size-weighted sum of the entropies of the two subgroups.

        Args:
            petal_length: Feature value of each sample.
            species: Class label of each sample, aligned with ``petal_length``.
                If the lists differ in length, the extra trailing items of the
                longer one are ignored (``zip`` semantics).

        Returns:
            The largest achievable information gain in bits, or ``0.0`` when
            either input is empty or no valid split exists (a single sample,
            or all feature values equal).
        """
        # Edge case: nothing to split.
        if not petal_length or not species:
            return 0.0

        # Sort by feature value only; labels never need to be comparable.
        samples = sorted(zip(petal_length, species), key=lambda pair: pair[0])
        n = len(samples)

        total_counts = collections.Counter(label for _, label in samples)
        original_group_entropy = self._entropyFromCounts(total_counts, n)

        # Running class counts on each side of the split.  Updating them
        # incrementally avoids re-counting both slices for every position.
        left_counts: collections.Counter = collections.Counter()
        right_counts = collections.Counter(total_counts)

        # "No split" leaves the entropy unchanged, so it is the baseline to
        # beat; this also keeps the result finite when no split is valid.
        best_weighted_entropy = original_group_entropy

        for i in range(1, n):
            prev_length, prev_label = samples[i - 1]
            left_counts[prev_label] += 1
            right_counts[prev_label] -= 1

            # Equal neighbouring values cannot be separated by any threshold.
            if samples[i][0] == prev_length:
                continue

            h_1 = self._entropyFromCounts(left_counts, i)
            h_2 = self._entropyFromCounts(right_counts, n - i)
            weighted = h_1 * (i / n) + h_2 * ((n - i) / n)
            best_weighted_entropy = min(best_weighted_entropy, weighted)

        return original_group_entropy - best_weighted_entropy

    def calculateEntropy(self, input: List[str]) -> float:
        """Return the Shannon entropy (base 2) of the labels in ``input``.

        An empty sequence has entropy ``0.0``.
        """
        return self._entropyFromCounts(collections.Counter(input), len(input))

    @staticmethod
    def _entropyFromCounts(counts: collections.Counter, total: int) -> float:
        """Return the base-2 entropy of a class distribution given as counts.

        Entries whose count is zero are skipped, so a counter that has been
        decremented down to zero for some label is handled correctly.
        """
        entropy = 0.0
        for count in counts.values():
            if count > 0:
                p = count / total
                entropy -= p * math.log2(p)
        return entropy


if __name__ == '__main__':
    # Small demo: four labelled samples given as (petal length, species).
    samples = [
        (0.5, 'setosa'),
        (2.3, 'versicolor'),
        (1.0, 'setosa'),
        (1.5, 'versicolor'),
    ]
    lengths = [length for length, _ in samples]
    labels = [label for _, label in samples]

    solver = Solution()
    print(solver.calculateMaxInfoGain(lengths, labels))
# Comments (1)