#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:14
# @Author  : 程婷婷
# @FileName: XgboostClassifyProcess.py
# @Software: PyCharm
import numpy as np
from sklearn.utils import class_weight
import pickle
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_selection import mutual_info_classif, SelectPercentile
import scipy.linalg
import jieba
from sklearn.base import BaseEstimator, TransformerMixin
from model.base.views.data.BaseDataProcess import BaseDataProcess
from model.classify.views.xgboost_classify.data.XgoostClassifyDataLoader import XgboostClassifyDataLoader

class Vocabulary:
    """Mapping between words and dense integer ids, with frequency statistics.

    Ids are handed out in first-seen order by ``update``, so they cover
    ``0 .. size-1``. Alongside ``voc`` (word -> id) the class keeps ``freq``
    (id -> total occurrences over all ``update`` calls) and ``doc_freq``
    (id -> number of ``update`` calls whose word list contained the word).
    Words rejected by ``_is_useless`` are never added. A vocabulary frozen with
    ``set_state(fixed=True)`` refuses further updates.
    """

    def __init__(self, signature, min_word_len=2, name='voc'):
        self.signature = signature  # appended to the file name by save()
        self.min_word_len = min_word_len  # words shorter than this are ignored
        self.name = name  # default file-name prefix used by save()
        self.voc = dict()  # word -> id
        self.freq = dict()  # id -> total occurrence count
        self.doc_freq = dict()  # id -> number of update() calls containing the word
        self.oov = None  # value returned by get()/[] for unknown words
        self.size = 0
        self._fixed_voc = False

    def set_state(self, fixed=False):
        """Freeze (truthy) or unfreeze (falsy) the vocabulary."""
        assert fixed in [True, False, 0, 1]
        self._fixed_voc = fixed

    def get_state(self):
        """Return ``'Fixed'`` or ``'Not fixed'``."""
        return 'Fixed' if self._fixed_voc else 'Not fixed'

    def shuffle(self):
        """Randomly permute the word ids, keeping ``freq``/``doc_freq`` aligned.

        Uses ``np.random.permutation`` (seed via ``np.random.seed`` for
        reproducibility).

        Raises:
            RuntimeError: if ``voc``, ``freq`` and ``doc_freq`` disagree with ``size``.
        """
        # The result of check_state() used to be discarded, so an inconsistent
        # vocabulary was silently permuted into a corrupt one.
        if not self.check_state():
            raise RuntimeError('Inconsistent vocabulary; cannot shuffle.')
        perm = np.random.permutation(self.size)
        shuffled_voc, shuffled_freq, shuffled_doc_freq = {}, {}, {}
        for word, old_id in self.voc.items():
            # Plain int, matching the ids that update() assigns.
            new_id = int(perm[old_id])
            shuffled_voc[word] = new_id
            shuffled_freq[new_id] = self.freq[old_id]
            shuffled_doc_freq[new_id] = self.doc_freq[old_id]
        self.voc, self.freq, self.doc_freq = shuffled_voc, shuffled_freq, shuffled_doc_freq

    def _is_useless(self, x):
        """Return True if ``x`` is shorter than ``min_word_len`` or consists only of
        ASCII letters/digits, whitespace and the listed punctuation characters."""
        if len(x) < self.min_word_len:
            return True
        if x.strip(
                '''#&$_%^*-+=<>`~!@(（）)?？/\\[]{}—"';:：；，。,.‘’“”|…\n abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890''') == '':
            return True
        return False

    def update(self, words):
        """Add the words of one document and update ``freq``/``doc_freq``.

        Arguments:
            words: iterable of tokens for a single document. It is materialised
                once, so one-shot iterators are counted correctly for both
                statistics.

        Raises:
            Exception: if the vocabulary is fixed.
        """
        if self._fixed_voc:
            raise Exception('Fixed vocabulary does not support update.')
        useful = [word for word in list(words) if not self._is_useless(word)]
        for word in useful:
            word_id = self.voc.get(word)
            if word_id is None:  # new word: next dense id
                self.voc[word] = self.size
                self.freq[self.size] = 1
                self.doc_freq[self.size] = 0
                self.size += 1
            else:
                self.freq[word_id] += 1
        # Each distinct useful word counts once per document.
        for word in set(useful):
            self.doc_freq[self.voc[word]] += 1

    def get(self, word):
        """Return the id of ``word``, or ``self.oov`` if unknown."""
        return self.voc.get(word, self.oov)

    def __getitem__(self, word):
        return self.voc.get(word, self.oov)

    def __contains__(self, word):
        return self.voc.__contains__(word)

    def __iter__(self):
        return iter(self.voc)

    def __sizeof__(self):
        # Shallow sizes only (sys.getsizeof semantics of each attribute).
        return self.voc.__sizeof__() + self.freq.__sizeof__() + self.signature.__sizeof__() + self.size.__sizeof__() + \
               self.name.__sizeof__() + self._fixed_voc.__sizeof__() + self.oov.__sizeof__() + self.doc_freq.__sizeof__()

    def __delitem__(self, word):  # deletion would break the dense 0..size-1 id layout
        if self._fixed_voc:
            raise Exception('Fixed vocabulary does not support deletion.')
        else:
            raise NotImplementedError

    def get_size(self):
        """Return the number of words."""
        return self.size

    def clear(self):
        """Remove all words and unfreeze the vocabulary."""
        self.voc = dict()
        self.freq = dict()
        self.doc_freq = dict()
        self.size = 0
        self._fixed_voc = False

    def check_state(self):
        """Return True if all three tables have exactly ``size`` entries."""
        return len(self.voc) == self.size and len(self.freq) == self.size and len(self.doc_freq) == self.size

    def to_dict(self):
        """Return the (live, not copied) word -> id mapping."""
        return self.voc

    def set_signature(self, new_signature):
        self.signature = new_signature

    def save(self, file_name=None):
        """Pickle the vocabulary to ``<file_name or name>-<signature>.voc``."""
        save_to = (file_name if file_name else self.name) + '-%s.voc' % self.signature
        with open(save_to, 'wb') as f:
            pickle.dump([self.voc,
                         self.freq,
                         self.doc_freq,
                         self.size,
                         self.min_word_len,
                         self.oov,
                         self._fixed_voc,
                         self.name,
                         self.signature], f)

    @classmethod
    def load(cls, file_name):
        """Load a vocabulary written by ``save``.

        Security note: this unpickles ``file_name``; only load trusted files.
        """
        with open(file_name, 'rb') as f:
            [voc, freq, doc_freq, size, min_word_len, oov, _fixed, name, signature] = pickle.load(f)

        # Pass by keyword: ``cls(signature, name)`` bound the name to
        # min_word_len and reset the loaded name to the default 'voc'.
        voc_from_file = cls(signature, min_word_len=min_word_len, name=name)
        voc_from_file.voc = voc
        voc_from_file.freq = freq
        voc_from_file.doc_freq = doc_freq
        voc_from_file.size = size
        voc_from_file.oov = oov
        voc_from_file._fixed_voc = _fixed
        return voc_from_file

class DataProcessor:
    """Turn space-joined token strings into (tf or tf-idf) weighted count matrices.

    The ``CountVectorizer`` is built from a ``Vocabulary``'s word -> id
    mapping, so feature columns follow the vocabulary ids.
    """

    def __init__(self, data, transformer='tf', transformer_norm='l2'):
        """
        Arguments:
            data: corpus used by ``preprocess`` (indexed there as
                ``data[key]`` -> records with ``record[0]`` text and
                ``record[1]`` label).
            transformer: ``'tf'`` or ``'tfidf'`` (case-insensitive).
            transformer_norm: ``norm`` passed to ``TfidfTransformer``.
        """
        self.data = data
        transformer = transformer.lower()
        assert transformer in ['tf', 'tfidf']
        self.transformer_type = transformer
        self.transformer_norm = transformer_norm
        self.transformer = None  # TfidfTransformer, built (and fitted) lazily by transform()
        self.cv = None  # CountVectorizer, built together with self.transformer

    def reset(self):
        """Drop the vectoriser/weighting so the next ``transform`` rebuilds and refits them."""
        self.transformer = None
        self.cv = None

    def preprocess(self, label_dict, _all=False, _emotion=False):
        """Tokenise every record of ``self.data`` with jieba.

        Returns:
            (processed_data, processed_label, processed_label_dict), each keyed
            like ``self.data``: an ndarray of space-joined token strings, the
            list of ``record[1]`` labels, and ``label_dict``. When ``_emotion``
            is truthy nothing is processed and all three dicts are empty.
            ``_all`` is unused.
        """
        processed_data = {}
        processed_label = {}
        processed_label_dict = {}
        for key in self.data:
            print(key)
            if not _emotion:
                records = self.data[key]
                processed_data[key] = np.array([' '.join(jieba.lcut(str(record[0]))) for record in records])
                processed_label[key] = [record[1] for record in records]
                processed_label_dict[key] = label_dict
                print(processed_label_dict)
        return processed_data, processed_label, processed_label_dict

    def update_vocab(self, vocab, processed_data):
        """Feed every space-separated record (from a dict of sequences or a flat iterable) into ``vocab``."""
        if isinstance(processed_data, dict):
            for key in processed_data:
                for record in processed_data[key]:
                    vocab.update(record.split(' '))
        else:
            for record in processed_data:
                vocab.update(record.split(' '))
        assert vocab.check_state(), 'Something wrong with vocabulary.'

    def transform(self, vocab, data, label, with_feature_selection=False, feature_selection_method='FDA', binary=False):
        """Vectorise ``data`` (a dict of document collections, or one collection).

        Freezes ``vocab``. On the first call after construction/``reset`` the
        vectoriser is built (``binary`` only matters then) and the weighting
        transformer is fitted on the counts of all documents passed in that
        call. Optionally reduces features with ``FDA`` or
        ``SelectPercentile(mutual_info_classif, percentile=20)`` fitted on
        ``label`` (a dict keyed like ``data`` when ``data`` is a dict).

        Returns:
            Matrix (or dict of matrices keyed like ``data``).
        """
        vocab.set_state(fixed=True)
        assert feature_selection_method in ['FDA', 'SelectPercentile']
        is_dict = isinstance(data, dict)
        if is_dict and not data:
            return {}
        newly_built = self.transformer is None
        if newly_built:
            self.cv = CountVectorizer(decode_error='replace', vocabulary=vocab.to_dict(), binary=binary)
            self.transformer = TfidfTransformer(norm=self.transformer_norm,
                                                use_idf=self.transformer_type == 'tfidf')
        if is_dict:
            counts = {key: self.cv.transform(data[key]) for key in data}
            if newly_built:
                # The transformer was previously never fitted: idf weighting
                # (and plain tf on recent scikit-learn) fails in transform().
                import scipy.sparse
                self.transformer.fit(scipy.sparse.vstack(list(counts.values())))
            return {
                key: self._weight_and_select(counts[key],
                                             label[key] if with_feature_selection else None,
                                             with_feature_selection, feature_selection_method)
                for key in counts
            }
        counts = self.cv.transform(data)
        if newly_built:
            self.transformer.fit(counts)
        return self._weight_and_select(counts, label, with_feature_selection, feature_selection_method)

    def _weight_and_select(self, counts, y, with_feature_selection, feature_selection_method):
        """Apply tf/tf-idf weighting, then optionally fit-transform a feature reducer on (weighted, y)."""
        weighted = self.transformer.transform(counts)
        if not with_feature_selection:
            return weighted
        if feature_selection_method == 'FDA':
            return FDA().fit_transform(weighted, y)
        # ``percentile`` is keyword-only in recent scikit-learn releases.
        return SelectPercentile(mutual_info_classif, percentile=20).fit_transform(weighted, y)


class FDA(BaseEstimator, TransformerMixin):
    """Fisher discriminant analysis projection.

    Solves the symmetric-definite generalized eigenproblem
    ``S_b v = lambda * S_w v``, where ``S_b`` is the between-class scatter and
    ``S_w`` the within-class scatter plus ``alpha * I``. It then projects data
    onto the eigenvectors in order of decreasing eigenvalue.
    """

    def __init__(self, alpha=1e-4, n_components=None):
        '''
        Arguments:
        ----------
        alpha : float
            Ridge added to the within-class scatter diagonal; keeps it positive
            definite (required by the symmetric solver).
        n_components : int or None
            Number of leading discriminant directions to keep. ``None`` keeps
            all ``n_features`` directions (the previous behaviour).
        '''
        self.alpha = alpha
        self.n_components = n_components

    def fit(self, X, Y):
        '''Fit the FDA projection.

        Parameters
        ----------
        X : array-like or scipy sparse matrix, shape [n_samples, n_features]
        Y : array-like, shape [n_samples]

        Returns
        -------
        self : object

        Classes with fewer than 2 samples are skipped (their covariance is
        undefined). Eigenvectors are real and normalised so that
        ``v.T @ S_w @ v == 1``.
        '''
        import scipy.sparse

        Y = np.asarray(Y)
        n, d_orig = X.shape
        assert (len(Y) == n)

        sparse_input = scipy.sparse.issparse(X)
        # CSR supports boolean row masks; dense inputs (incl. np.matrix) become ndarrays.
        X = scipy.sparse.csr_matrix(X) if sparse_input else np.asarray(X)

        mean_global = np.asarray(X.mean(axis=0)).reshape(1, -1)
        scatter_within = self.alpha * np.eye(d_orig)
        scatter_between = np.zeros_like(scatter_within)

        for c in np.unique(Y):
            mask = Y == c
            n_c = int(mask.sum())
            if n_c < 2:
                continue
            X_c = X[mask]
            mu_diff = np.asarray(X_c.mean(axis=0)).reshape(1, -1) - mean_global
            scatter_between += n_c * (mu_diff.T @ mu_diff)
            dense_c = X_c.toarray() if sparse_input else X_c
            scatter_within += n_c * np.cov(dense_c, rowvar=False)

        # eigh (not eig): the pencil is symmetric/positive-definite, so this
        # yields real eigenpairs. eig returned complex arrays and hence
        # complex-valued features.
        e_vals, e_vecs = scipy.linalg.eigh(scatter_between, scatter_within)
        order = np.argsort(e_vals)[::-1]
        e_vals, e_vecs = e_vals[order], e_vecs[:, order]
        if self.n_components is not None:
            e_vals = e_vals[:self.n_components]
            e_vecs = e_vecs[:, :self.n_components]

        self.e_vals_ = e_vals
        self.e_vecs_ = e_vecs
        self.components_ = e_vecs.T
        return self

    def transform(self, X):
        '''Project X onto the fitted discriminant directions.

        Parameters
        ----------
        X : array-like or sparse matrix, shape [n_samples, n_features]

        Returns
        -------
        X_new : array, shape (n_samples, n_components)
        '''
        return X.dot(self.components_.T)

    def fit_transform(self, X, Y):
        """Fit on (X, Y) and return the projection of X."""
        self.fit(X, Y)
        return self.transform(X)


class XgboostClassifyProcess(BaseDataProcess):
    """Build XGBoost train/test matrices from the configured corpus.

    Relies on ``BaseDataProcess`` (defined elsewhere) for ``embedding_config``,
    ``process_config`` and ``split_dataset`` — presumably populated from
    ``config_path``; confirm against the base class.
    """

    def __init__(self, config_path):
        super().__init__(config_path)
        self.xcdl = XgboostClassifyDataLoader(config_path)

    def class_weight(self, y_train):
        """Return ``{label: weight}`` with scikit-learn's 'balanced' class weights.

        Keys are the actual label values in ``y_train``. The previous
        ``enumerate`` keys matched only when the labels were exactly
        ``0..k-1``; they were wrong whenever a class was absent from
        ``y_train``.
        """
        classes = np.unique(y_train)
        # classes/y are keyword-only in recent scikit-learn releases.
        # Note: the module-level ``class_weight`` (sklearn.utils) is used here,
        # not this method.
        weight = class_weight.compute_class_weight('balanced', classes=classes, y=y_train)
        return dict(zip(classes.tolist(), weight))

    def runner_process(self, signature):
        """Load, tokenise, vectorise and split the corpus.

        Saves the vocabulary under ``embedding_config['embedding_path']`` and
        returns ``(train_set, test_set)`` from ``split_dataset``. Each row is
        the dense feature vector followed by the integer label id.
        """
        df = self.xcdl.read_file()
        unique_labels = set(df['label'])
        # Sort for a run-to-run stable label -> id mapping: iteration order of a
        # set of strings changes between interpreter runs (hash randomisation).
        try:
            all_label = sorted(unique_labels)
        except TypeError:  # mixed, mutually non-comparable label types
            all_label = sorted(unique_labels, key=str)
        self.label_mapping = {label: idx for idx, label in enumerate(all_label)}
        labels = df['label'].map(self.label_mapping)
        # str(): jieba.lcut fails on non-str cells (e.g. NaN); matches DataProcessor.preprocess.
        processed_data = df['content'].map(lambda x: ' '.join(jieba.lcut(str(x))))
        dp = DataProcessor(processed_data,
                           transformer=self.embedding_config['transformer'],
                           transformer_norm=self.embedding_config['transformer_norm'])
        dp.reset()
        vocab = Vocabulary(signature=signature, name='vocab-%s' % self.embedding_config['name'], min_word_len=2)
        dp.update_vocab(vocab, processed_data)
        print('%s, after updating, %s' % (self.embedding_config['name'], vocab.get_size()))
        transformed_data = dp.transform(vocab, processed_data, labels)
        vocab_save_to = self.embedding_config['embedding_path']
        print(vocab.to_dict())
        vocab.save(vocab_save_to)
        merged_data = np.append(transformed_data.toarray(), labels.values.reshape((-1, 1)), axis=1)
        print(merged_data.shape)
        train_set, test_set = self.split_dataset(merged_data, self.process_config['use_dev'])
        return train_set, test_set
# import time
# signature = int(time.time())
# XgboostClassifyProcess().runner_process(signature)