Pythonでベイジアンフィルタの実装

# -*- coding: utf-8 -*-
"""ベイジアンフィルタ実装."""
import math
import sys

class NaiveBayes(object):
    """単純ナイーブベイズクラス."""
    def __init__(self):
        """コンストラクタ.
        @set() vocabularies
        @dict word_count
        @dict category_count
        """

        self.vocabularies = set()
        self.word_count = {}
        self.category_count = {}

    def count_word(self, word, category):
        """単語カウンタ.
        @param (string, string)
        """

        self.word_count.setdefault(category, {})
        self.word_count[category].setdefault(word, 0)
        self.word_count[category][word] += 1
        self.vocabularies.add(word)

    def count_category(self, category):
        """カテゴリカウンタ.
        @param (string)
        """

        self.category_count.setdefault(category, 0)
        self.category_count[category] += 1

    def train(self, text, category):
        """学習データ作成.
        @param (string, string)
        """

        """
        words = 形態素解析でキーワード抽出
                YahooAPI,mecab, etc...
        """

        for word in words:
            self.count_word(word, category)

        self.count_category(category)

    def prior_prob(self, category):
        """categoryの生起確率計算.
        @param (string)
        """
        num_category = sum(self.category_count.values())
        num_docs = self.category_count[category]

        return num_docs / num_category

    def num_of_appearance(self, word, category):
        """wordの出現回数.
        @param (string, string)
        """
        if word in self.word_count[category]:
            return self.word_count[category][word]

        return 0

    def word_prob(self, word, category):
        """ベイズの法則.
        @param (string, string)
        """
        numerator = self.num_of_appearance(word, category) + 1
        denominator = sum(self.word_count[category].values()) + len(self.vocabularies)

        prob = numerator / denominator

        return prob

    def score(self, words, category):
        """スコア計算.
        @param (string, string)
        """
        score = math.log(self.prior_prob(category))
        for word in words:
            score += math.log(self.word_prob(word, category))

        return score

    def classify(self, doc):
        """分類器.
        @param (string)
        """
        best_guessed_category = None
        max_prob_before = -sys.maxsize

        ansys = AnalysisSelectData()
        keyphrase_dict = ansys.extract_keyphrase(doc)
        words = ansys.keyphrase_tuple(keyphrase_dict)

        for category in self.category_count.keys():
            prob = self.score(words, category)
            if prob > max_prob_before:
                max_prob_before = prob
                best_guessed_category = category

        return best_guessed_category

nb = NaiveBayes()
nb.train("料理(りょうり)は、食物をこしらえることで、同時に、こしらえた結果である食品そのもの[1]。調理ともいう[1]。すなわち、食材、調味料などを組み合わせて加工を行うこと、およびそれを行ったものの総称である。",
             "料理")
nb.train("ネコ(猫)は、狭義にはネコ目(食肉目)- ネコ亜目- ネコ科- ネコ亜科- ネコ属-ヤマネコ種-イエネコ亜種に分類される小型哺乳類であるイエネコ(家猫、学名:Felis silvestriscatus)の通称である。人間によくなつくため、イヌ(犬)と並ぶ代表的なペットとして世界中で広く飼われている。より広義には、ヤマネコやネコ科動物全般を指すこともある","ネコ")

print(nb.classify("食物"))
#=> "料理"