#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:14
# @Author  : 程婷婷
# @FileName: XgboostClassifyProcess.py
# @Software: PyCharm
from flair.data import Sentence, Corpus
import re
from transformers import AutoTokenizer
from torch.utils.data import Dataset
from flair.embeddings import TransformerDocumentEmbeddings
from model.base import BaseDataProcess
from model.classify import FlairClassifyDataLoader


class DataSet(Dataset):
    """Torch ``Dataset`` that yields one labelled flair ``Sentence`` per row.

    The ``<sep>`` placeholder found in each row's ``content`` is swapped for
    the tokenizer's own separator token once, at construction time.
    """

    def __init__(
            self, data_df, tokenizer,
    ):
        """Store the samples, labels and tokenizer.

        Args:
            data_df: Table exposing ``content`` and ``label`` columns
                (presumably a pandas DataFrame; it is copied, not modified).
            tokenizer: Object providing ``special_tokens_map['sep_token']``,
                ``tokenize`` and ``unk_token``.
        """
        # Work on a copy so the stored arrays never alias the caller's frame.
        df = data_df.copy()
        sep_token = tokenizer.special_tokens_map['sep_token']
        # Literal replacement: the separator token must be inserted verbatim.
        # A regex substitution would interpret backslashes / group references
        # inside the token as a replacement template.
        self.samples = df.content.apply(lambda s: s.replace("<sep>", sep_token)).values
        self.labels = df.label.values
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Build the ``Sentence`` for row *index* and tag it with its label.

        When the sample produces a sentence of length zero, a sentence built
        from the tokenizer's unknown token is used instead, and both the
        offending sample and the substitute are printed.
        """
        sample, label = self.samples[index], self.labels[index]
        sentence = Sentence(sample, use_tokenizer=self.tokenizer.tokenize)
        if len(sentence) == 0:
            # Fall back to the unknown token so the sentence is never empty.
            sentence = Sentence(self.tokenizer.unk_token, use_tokenizer=self.tokenizer.tokenize)
            print(sample)
            print(sentence)
        # The label is always attached in its string form under type 'class'.
        sentence.add_label('class', str(label))
        return sentence


class FlairClassifyProcess(BaseDataProcess):
    """Turn the raw classification table into a flair ``Corpus`` plus embeddings.

    Reads ``self.process_config`` and ``self.embedding_config``, which are not
    set in this class (presumably provided by ``BaseDataProcess`` — confirm).
    """

    def __init__(self, config_path):
        """Initialise the base processor and the data loader from *config_path*."""
        super().__init__(config_path)
        self.fcdl = FlairClassifyDataLoader(config_path)

    @staticmethod
    def add_sep_token(content):
        """Return *content* with ``<sep>`` inserted after every ``。``."""
        return re.sub('。', '。<sep>', content)

    def runner_process(self):
        """Load, filter and split the data, then build the training objects.

        Side effects: sets ``self.loss_weights`` (dict form of the weights)
        and, when ``process_config['label_encode']`` is truthy,
        ``self.label_mapping`` (label -> integer index).

        Returns:
            A 4-tuple ``(corpus, document_embeddings, label_dict,
            loss_weights)`` where ``loss_weights`` is the per-label
            ``(total - count) / count`` result derived from
            ``df.label.value_counts()``.
        """
        df = self.fcdl.read_file()
        # Coerce to text before measuring: the later ``str(s)`` shows that
        # non-string cells are expected, and calling ``.strip()`` on them
        # directly would raise before that coercion is ever reached.
        content_text = df['content'].map(str)
        # Keep only rows whose stripped content is longer than 10 characters.
        df = df[content_text.apply(lambda s: len(s.strip())) > 10]
        df = df.reset_index(drop=True)
        df['content'] = df['content'].apply(lambda s: self.add_sep_token(str(s)))

        # Weight each label by the ratio of "all other rows" to "its own rows".
        pos = df.label.value_counts()
        loss_weights = (pos.sum() - pos) / pos
        self.loss_weights = loss_weights.to_dict()

        if self.process_config['label_encode']:
            # Sort so the label -> index mapping is reproducible between runs;
            # plain set iteration order is not stable for string labels.
            try:
                all_label = sorted(set(df['label']))
            except TypeError:
                # Labels that cannot be compared with each other: fall back to
                # first-appearance order, which is still deterministic.
                all_label = list(dict.fromkeys(df['label']))
            self.label_mapping = {label: code for code, label in enumerate(all_label)}
            # NOTE(review): the encoded labels are only printed, never written
            # back to ``df`` — confirm whether that is intentional.
            labels = df['label'].map(self.label_mapping)
            print(labels)

        tokenizer = AutoTokenizer.from_pretrained(self.embedding_config['pretrained_name'])
        use_dev = self.process_config['use_dev']
        splits = self.split_dataset(df, use_dev=use_dev)
        if use_dev:
            # With a dev split the helper yields (train, test, dev), in that order.
            train_data_set, test_data_set, dev_data_set = splits
            train_set = DataSet(train_data_set, tokenizer)
            test_set = DataSet(test_data_set, tokenizer)
            val_set = DataSet(dev_data_set, tokenizer)
            corpus = Corpus(train=train_set, dev=val_set, test=test_set)
        else:
            train_data_set, test_data_set = splits
            train_set = DataSet(train_data_set, tokenizer)
            test_set = DataSet(test_data_set, tokenizer)
            corpus = Corpus(train=train_set, test=test_set)

        label_dict = corpus.make_label_dictionary()
        document_embeddings = TransformerDocumentEmbeddings(
            self.embedding_config['pretrained_name'], fine_tune=True
        )
        return corpus, document_embeddings, label_dict, loss_weights
