#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/10 15:28
# @Author  : 程婷婷
# @FileName: BaseDataProcess.py
# @Software: PyCharm
import logging
import os
import pickle
import re

import gensim
import jieba
import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from pyhanlp import *
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import mutual_info_classif, SelectPercentile

from base.views.config.BaseConfig import BaseConfig
from base.views.data.BaseDataLoader import BaseDataLoader
from platform_zzsn.settings import BASE_DIR

# Layout of every log line emitted through the root logger configured below.
# Named LOG_FORMAT rather than `format` so it does not shadow the built-in
# format() in this module or in modules that star-import it.
LOG_FORMAT = '%(asctime)s %(levelname)s %(pathname)s %(funcName)s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)


class BaseDataProcess:
    """Clean, tokenize, split and vectorize text according to a config file.

    The ``data_process`` section of the config drives cleaning, tokenization
    and dataset splitting; the ``embedding`` section drives the TF-IDF and
    word2vec feature extraction.
    """

    # Keep CJK ideographs (U+4E00-U+9FA5), ASCII letters, digits, whitespace
    # and the listed Chinese/ASCII punctuation; every other character is
    # removed. Compiled once here instead of on every remove_char() call.
    _CHAR_FILTER = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s，。\.,？\?!！；;]')

    def __init__(self, config_path):
        # Parse the config file once and keep the two sections this class uses.
        parsed_config = BaseConfig(config_path)._parsed_file
        self.embedding_config = parsed_config['embedding']
        self.process_config = parsed_config['data_process']
        PerceptronLexicalAnalyzer = JClass('com.hankcs.hanlp.model.perceptron.PerceptronLexicalAnalyzer')
        self.pla_segment = PerceptronLexicalAnalyzer()
        self.bdl = BaseDataLoader(config_path)

    def clean_content(self, content):
        """Parse *content* as HTML and return only its text."""
        return BeautifulSoup(content, 'html.parser').text

    def remove_char(self, content):
        """Delete every character of *content* not allowed by ``_CHAR_FILTER``."""
        return self._CHAR_FILTER.sub('', content)

    def _get_stopwords(self):
        """Return the stop words to filter on, or an empty set when disabled.

        The loader result is cached on the instance because the tokenizers are
        invoked once per record.
        """
        if not self.process_config['use_stopwords']:
            return frozenset()
        cached = getattr(self, '_stopwords_cache', None)
        if cached is None:
            # NOTE(review): assumes read_stopwords() returns an iterable of
            # words (e.g. a list); a frozenset gives O(1) membership tests.
            cached = frozenset(self.bdl.read_stopwords())
            self._stopwords_cache = cached
        return cached

    def jieba_tokenizer(self, content):
        """Segment *content* with jieba; return the non-stop words joined by spaces."""
        stopwords = self._get_stopwords()
        return ' '.join(word for word in jieba.lcut(content) if word not in stopwords)

    def pla_tokenizer(self, content):
        """Segment *content* with HanLP's perceptron analyzer; return the
        non-stop words joined by spaces."""
        # str() normalizes the elements of the Java String[] in case JPype
        # hands back java.lang.String objects rather than Python str.
        words = [str(word) for word in self.pla_segment.analyze(content).toWordArray()]
        stopwords = self._get_stopwords()
        return ' '.join(word for word in words if word not in stopwords)

    def save(self, voc, path):
        """Pickle *voc* to the file at *path* (overwriting it)."""
        with open(path, 'wb') as voc_file:
            pickle.dump(voc, voc_file)

    @staticmethod
    def _normalize_tokens(tokenized):
        """Re-join a space-separated token string with single spaces.

        Empty and whitespace-only tokens (jieba emits whitespace as tokens)
        are dropped; the word boundaries themselves are preserved.
        """
        return ' '.join(tokenized.split())

    def process(self, data, min_content=0):
        """Clean and tokenize every record of *data*.

        :param data: iterable of raw records; each is converted with ``str()``.
        :param min_content: records whose cleaned text has ``<= min_content``
            characters are dropped.
        :return: list of space-separated token strings, one per kept record.
        """
        # Tokenizer choice is fixed by config; resolve it once, not per record.
        if self.process_config['tokenizer'] == 'PerceptronLexicalAnalyzer':
            tokenize = self.pla_tokenizer
        else:
            tokenize = self.jieba_tokenizer

        processed_data = []
        for record in data:
            text = self.remove_char(self.clean_content(str(record)))
            if len(text) <= min_content:
                continue
            # The tokenizers return one space-joined string: split it into
            # words (not characters) before normalizing the spacing.
            processed_data.append(self._normalize_tokens(tokenize(text)))
        return processed_data

    def split_dataset(self, data, use_dev):
        """Shuffle and split *data* using ``train_size``/``test_size``/``random_state``.

        :param data: the samples to split.
        :param use_dev: when true, the held-out part is split again into
            dev and test; ``test_size`` is the test share of that held-out part.
        :return: ``(train, test)`` or, with *use_dev*, ``(train, test, dev)``.
        """
        random_state = self.process_config['random_state']
        train_data_set, held_out = train_test_split(data,
                                                    train_size=self.process_config['train_size'],
                                                    random_state=random_state,
                                                    shuffle=True)
        if not use_dev:
            return train_data_set, held_out

        # train_test_split on a single array yields exactly two parts, so the
        # held-out data is split a second time to obtain dev and test.
        dev_data_set, test_data_set = train_test_split(held_out,
                                                       test_size=self.process_config['test_size'],
                                                       random_state=random_state,
                                                       shuffle=True)
        logging.info('split_dataset: train=%d test=%d dev=%d',
                     len(train_data_set), len(test_data_set), len(dev_data_set))
        return train_data_set, test_data_set, dev_data_set

    @staticmethod
    def _feature_names(vectorizer):
        """Return a fitted vectorizer's feature names as a list of str."""
        # get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2.
        if hasattr(vectorizer, 'get_feature_names_out'):
            return list(vectorizer.get_feature_names_out())
        return vectorizer.get_feature_names()

    def bag_of_words(self, data, label):
        """Build TF-IDF features for *data*, optionally keeping the top 20% by mutual information.

        :param data: iterable of space-separated token strings.
        :param label: class labels, used only when ``with_feature_selection`` is set.
        :return: ``(dense feature matrix, feature names)``; the names are
            aligned with the returned matrix columns.
        Side effect: pickles the CountVectorizer vocabulary to
        ``<embedding_path>/tfidf.pkl``.
        """
        # Unigrams seen in at least 5 documents.
        vectorizer = CountVectorizer(ngram_range=(1, 1), min_df=5)
        counts = vectorizer.fit_transform(data)
        transformer = TfidfTransformer(norm=self.embedding_config['norm'], use_idf=self.embedding_config['use_idf'],
                                       smooth_idf=self.embedding_config['smooth_idf'])
        x = transformer.fit_transform(counts).toarray()
        feature_names = self._feature_names(vectorizer)
        if self.embedding_config['with_feature_selection']:
            # `percentile` is keyword-only in scikit-learn >= 1.0.
            selector = SelectPercentile(mutual_info_classif, percentile=20)
            transformed_data = selector.fit_transform(x, label)
            # Keep the names in step with the surviving columns.
            # NOTE(review): the fitted selector itself is not persisted.
            feature_names = [name for name, keep in zip(feature_names, selector.get_support()) if keep]
        else:
            transformed_data = x
        os.makedirs(self.embedding_config['embedding_path'], exist_ok=True)
        self.save(voc=vectorizer.vocabulary_, path=os.path.join(self.embedding_config['embedding_path'], 'tfidf.pkl'))
        return transformed_data, feature_names

    def _word2vec_params(self, size):
        """Return Word2Vec keyword arguments named for the installed gensim version."""
        params = {
            'window': self.embedding_config['window'],
            'min_count': self.embedding_config['min_count'],
            'workers': self.embedding_config['workers'],
            'sg': self.embedding_config['sg'],
        }
        # gensim 4.0 renamed `size` -> `vector_size` and `iter` -> `epochs`.
        if int(gensim.__version__.split('.')[0]) >= 4:
            params.update(vector_size=size, epochs=self.embedding_config['iter'])
        else:
            params.update(size=size, iter=self.embedding_config['iter'])
        return params

    def word2vec(self, data, feature_words):
        """Train word2vec on *data* and build an embedding row for each feature word.

        :param data: training corpus; each item is a list of tokens or a
            space-separated token string (as produced by ``process``).
        :param feature_words: words to embed, one matrix row each, in order.
        :return: ndarray of shape ``(len(feature_words), dim)``. ``dim`` is the
            configured ``size``, or the Tencent vector size plus ``size`` when
            ``use_Tencent`` is set (Tencent columns first). A word missing from
            a model gets a standard-normal random vector for that model's part.
        Side effect: saves the trained model as ``<embedding_path>/word2vec.model``.
        """
        size = int(self.embedding_config['size'])
        # Word2Vec expects token lists; a plain string would be split into characters.
        sentences = [s.split() if isinstance(s, str) else s for s in data]
        model = gensim.models.word2vec.Word2Vec(sentences=sentences, **self._word2vec_params(size))

        sources = [model.wv]
        if self.embedding_config['use_Tencent']:
            model_tencent = gensim.models.KeyedVectors.load_word2vec_format(
                os.path.join(BASE_DIR, 'static/base/Tencent_AILab_ChineseEmbedding.bin'), binary=True)
            sources.insert(0, model_tencent)
        dims = [keyed_vectors.vector_size for keyed_vectors in sources]

        vector_matrix = np.zeros((len(feature_words), sum(dims)))
        for row, word in enumerate(feature_words):
            # `in` / `[]` on KeyedVectors work in both gensim 3.x and 4.x.
            parts = [kv[word] if word in kv else np.random.randn(dim) for kv, dim in zip(sources, dims)]
            vector_matrix[row] = np.concatenate(parts)

        os.makedirs(self.embedding_config['embedding_path'], exist_ok=True)
        model.save(os.path.join(self.embedding_config['embedding_path'], 'word2vec.model'))
        return vector_matrix
