import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global RNG at import time so the random initialisations made
# later (e.g. the torch.randn transition matrix and initial LSTM hidden state)
# are reproducible from run to run.
torch.manual_seed(seed=1)



class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder topped with a linear-chain CRF.

    The BiLSTM turns a sentence of word indices into per-token emission
    scores over the tag set, and a learned transition matrix scores moves
    between consecutive tags. ``forward`` returns the Viterbi-best tag
    sequence together with its score.

    Args:
        vocab_size: Number of rows in the word-embedding table.
        tag_to_ix: Mapping from tag string to integer index. Must contain
            entries for ``START_TAG`` and ``STOP_TAG``.
        embedding_dim: Size of each word embedding vector.
        hidden_dim: Combined BiLSTM hidden size. Each direction gets
            ``hidden_dim // 2`` units and the output layer expects
            ``hidden_dim`` inputs, so this should be even.
    """

    # Sentinel tags that frame every sequence; tag_to_ix must map both.
    START_TAG = "<START>"
    STOP_TAG = "<STOP>"

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Fully connected layer mapping the concatenated BiLSTM output at
        # each position to one emission score per tag.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Transition scores: transitions[i, j] is the score of moving FROM
        # tag j TO tag i (rows are the destination tag).
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # Nothing may transition into START, and nothing may leave STOP.
        # A large negative score effectively rules these moves out.
        with torch.no_grad():
            self.transitions[tag_to_ix[self.START_TAG], :] = -10000
            self.transitions[:, tag_to_ix[self.STOP_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh random ``(h_0, c_0)`` pair for the BiLSTM.

        Each tensor has shape ``(2, 1, hidden_dim // 2)``: two directions,
        batch size one. They are created on the same device and dtype as the
        model parameters.
        """
        shape = (2, 1, self.hidden_dim // 2)
        opts = dict(dtype=self.transitions.dtype,
                    device=self.transitions.device)
        return (torch.randn(*shape, **opts),
                torch.randn(*shape, **opts))

    def _get_lstm_features(self, sentence):
        """Compute per-token emission scores for ``sentence``.

        Args:
            sentence: 1-D tensor of word indices, one per token.

        Returns:
            Tensor of shape ``(len(sentence), tagset_size)``.
        """
        # Start every sentence from a fresh hidden state so no state (or
        # autograd history) leaks between sentences.
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), -1)
        return self.hidden2tag(lstm_out)

    def _initial_scores(self):
        """Return the (1, tagset_size) start vector: 0 for START, -10000 else."""
        scores = torch.full((1, self.tagset_size), -10000.,
                            dtype=self.transitions.dtype,
                            device=self.transitions.device)
        scores[0][self.tag_to_ix[self.START_TAG]] = 0.
        return scores

    def _forward_alg(self, feats):
        """Compute the log partition function (log Z) of the CRF.

        Args:
            feats: Emission scores of shape ``(seq_len, tagset_size)``.

        Returns:
            0-dim tensor holding log-sum-exp of the scores of all tag paths
            that start at START and end at STOP.
        """
        stop_ix = self.tag_to_ix[self.STOP_TAG]
        forward_var = self._initial_scores()

        for feat in feats:
            # scores[j, i] = alpha[i] + transition(i -> j) + emission(j).
            # The emission score does not depend on the previous tag, so it
            # broadcasts across each row.
            scores = forward_var + self.transitions + feat.unsqueeze(1)
            forward_var = torch.logsumexp(scores, dim=1).view(1, -1)

        terminal_var = forward_var + self.transitions[stop_ix]
        return torch.logsumexp(terminal_var.view(-1), dim=0)

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag path via the Viterbi algorithm.

        Args:
            feats: Emission scores of shape ``(seq_len, tagset_size)``.

        Returns:
            ``(path_score, best_path)``: the score of the best path as a
            0-dim tensor, and the best tag indices as a list of ints with one
            entry per row of ``feats`` (START/STOP excluded).
        """
        start_ix = self.tag_to_ix[self.START_TAG]
        stop_ix = self.tag_to_ix[self.STOP_TAG]
        backpointers = []

        # At step t, forward_var holds the best score of any path ending in
        # each tag at step t-1.
        forward_var = self._initial_scores()
        for feat in feats:
            # scores[j, i] = best score ending in tag i + transition(i -> j).
            scores = forward_var + self.transitions
            best_scores, best_prev = torch.max(scores, dim=1)
            backpointers.append(best_prev.tolist())
            # Emissions are added after the max because they do not depend
            # on the previous tag.
            forward_var = (best_scores + feat).view(1, -1)

        # Transition into STOP.
        terminal_var = forward_var + self.transitions[stop_ix]
        best_score, best_tag = torch.max(terminal_var, dim=1)
        path_score = best_score[0]
        best_tag_id = best_tag.item()

        # Follow the backpointers from the last position to the first.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        # The final backpointer leads to START, which is not part of the
        # returned sequence.
        start = best_path.pop()
        assert start == start_ix
        best_path.reverse()
        return path_score, best_path

    def forward(self, sentence):
        """Decode the best tag sequence for ``sentence``.

        Args:
            sentence: 1-D tensor of word indices.

        Returns:
            ``(score, tag_seq)`` as returned by ``_viterbi_decode`` on the
            BiLSTM emission scores.
        """
        lstm_feats = self._get_lstm_features(sentence)
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

