# -*- coding: utf-8 -*-

"""
@Time    :  2023/3/16 11:03
@Author  : 
@FileName: 
@Software: 
@Describe: ROUGE-W (weighted LCS) and ROUGE-L (LCS) F-scores over token sequences.
"""
from rouge import Rouge
# Scorer instance from the third-party `rouge` package, built at import time.
# It is not referenced by Rouge_w / Rouge_l below.
# NOTE(review): possibly used by importers of this module — confirm before
# removing, since constructing it here is an import-time side effect.
rouge = Rouge()
from copy import deepcopy

class Rouge_w:
    """ROUGE-W: F-score based on the weighted longest common subsequence (WLCS).

    A match that extends a run of k consecutive matches adds
    ``f(k + 1) - f(k)`` (scaled by 10.0) to the score. ``f`` is quadratic, so
    consecutive matches are rewarded more than scattered ones.
    """

    def __init__(self, p=1.0):
        """
        Args:
            p: beta of the F-measure that combines WLCS precision and recall.
               Default 1.0, which gives 2*P*R / (P + R).
        """
        # Scale of the weight function ``f``. WLCS multiplies every increment
        # by 10.0, so with k = 0.1 the effective weight is a**2. That matches
        # the normaliser ``f_`` (a**2) and its inverse ``f_1`` (sqrt) used in
        # ``score``. Changing k without the 10.0 factor breaks normalisation.
        self.k = 0.1
        # Slope used only by ``fi_``; ``score`` does not use it.
        self.ki = 1.2
        # beta of the F-measure.
        self.p = p

    def fi_(self, a):
        """Linear weight ``a * ki`` (not used by ``score``)."""
        return a * self.ki

    def f(self, a):
        """Quadratic weight ``k * a**2`` applied to run lengths in WLCS."""
        return self.k * (a ** 2)

    def WLCS(self, X, Y, f):
        """Return the weighted LCS score of sequences ``X`` and ``Y``.

        Args:
            X, Y: sequences whose elements are compared with ``==``.
            f: weight function of the current run length. A match that
               extends a run of length k adds ``10.0 * (f(k + 1) - f(k))``.

        Returns:
            The accumulated weighted score (0 if either sequence is empty).
        """
        m = len(X)
        n = len(Y)
        # c[i][j]: best weighted score for X[:i] vs Y[:j].
        # w[i][j]: length of the consecutive-match run ending at (i, j).
        c = [[0] * (n + 1) for _ in range(m + 1)]
        w = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if X[i - 1] == Y[j - 1]:
                    k = w[i - 1][j - 1]
                    c[i][j] = c[i - 1][j - 1] + 10.0 * (f(k + 1) - f(k))
                    w[i][j] = k + 1
                elif c[i - 1][j] > c[i][j - 1]:
                    # Mismatch breaks the run; w[i][j] stays 0.
                    c[i][j] = c[i - 1][j]
                else:
                    c[i][j] = c[i][j - 1]

        return c[m][n]

    def f_1(self, k):
        """Inverse of ``f_`` (square root)."""
        return k ** 0.5

    def f_(self, k):
        """Square of ``k``; normalises the WLCS score by sequence length."""
        return k ** 2

    def score(self, p, r):
        """Return the ROUGE-W F-score of sequence ``p`` against reference ``r``.

        Args:
            p: candidate sequence (its length normalises precision).
            r: reference sequence (its length normalises recall).

        Returns:
            The F-score as a float; 0.0 when either sequence is empty.
        """
        m = len(p)
        n = len(r)
        if m == 0 or n == 0:
            # Nothing can match. This also avoids 0 / 0 in the normalisation.
            return 0.0
        wlcs = self.WLCS(p, r, self.f)
        p_wlcs = self.f_1(wlcs / self.f_(m))
        r_wlcs = self.f_1(wlcs / self.f_(n))
        # (1 + beta^2) * P * R / (P + beta^2 * R); 1e-8 guards against 0 / 0.
        # NOTE(review): Lin (2004) puts beta^2 on P in the denominator (beta > 1
        # favours recall). The two forms only differ when p != 1.0 — confirm
        # the intended weighting before using a non-default beta.
        f_lcs = (1 + self.p ** 2) * ((p_wlcs * r_wlcs) / (p_wlcs + ((self.p ** 2) * r_wlcs) + 1e-8))
        return f_lcs

class Rouge_l:
    """ROUGE-L: F-score based on the longest common subsequence (LCS)."""

    def __init__(self, b=3):
        """
        Args:
            b: beta of the F-measure that combines LCS precision and recall
               (default 3).
        """
        self.b = b

    def LCS(self, X, Y):
        """Return the length of the longest common subsequence of ``X`` and ``Y``.

        Elements are compared with ``==``. Returns 0 if either is empty.
        """
        m = len(X)
        n = len(Y)
        # dp[i][j] holds the LCS length of X[:i] and Y[:j].
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        # Fill the table with the standard dynamic-programming recurrence.
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if X[i - 1] == Y[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        return dp[m][n]

    def score(self, p, r):
        """Return the ROUGE-L F-score of sequence ``p`` against reference ``r``.

        Args:
            p: candidate sequence (its length normalises precision).
            r: reference sequence (its length normalises recall).

        Returns:
            The F-score as a float; 0.0 when either sequence is empty.
        """
        m = len(p)
        n = len(r)
        if m == 0 or n == 0:
            # Nothing can match. This also avoids dividing by a zero length.
            return 0.0
        lcs = self.LCS(p, r)
        p_lcs = lcs / m
        r_lcs = lcs / n
        # (1 + b^2) * P * R / (P + b^2 * R); 1e-8 guards against 0 / 0.
        # NOTE(review): Lin (2004) defines F_lcs with b^2 on P in the
        # denominator, so b > 1 favours recall. Here b^2 multiplies R, so the
        # default b = 3 favours precision instead. Kept as-is to preserve
        # existing scores — confirm which weighting is intended.
        f_lcs = ((1 + self.b ** 2) * (p_lcs * r_lcs) / (p_lcs + self.b ** 2 * r_lcs + 1e-8))
        return f_lcs


# class Ngrams(object):
#     """
#         Ngrams datastructure based on `set` or `list`
#         depending on `exclusive`
#     """
#
#     def __init__(self, ngrams={}, exclusive=True):
#         if exclusive:
#             self._ngrams = set(ngrams)
#         else:
#             self._ngrams = list(ngrams)
#         self.exclusive = exclusive
#
#     def add(self, o):
#         if self.exclusive:
#             self._ngrams.add(o)
#         else:
#             self._ngrams.append(o)
#
#     def __len__(self):
#         return len(self._ngrams)
#
#     def intersection(self, o):
#         if self.exclusive:
#             inter_set = self._ngrams.intersection(o._ngrams)
#             return Ngrams(inter_set, exclusive=True)
#         else:
#             other_list = deepcopy(o._ngrams)
#             inter_list = []
#
#             for e in self._ngrams:
#                 try:
#                     i = other_list.index(e)
#                 except ValueError:
#                     continue
#                 other_list.pop(i)
#                 inter_list.append(e)
#             return Ngrams(inter_list, exclusive=False)
#
#     def union(self, *ngrams):
#         if self.exclusive:
#             union_set = self._ngrams
#             for o in ngrams:
#                 union_set = union_set.union(o._ngrams)
#             return Ngrams(union_set, exclusive=True)
#         else:
#             union_list = deepcopy(self._ngrams)
#             for o in ngrams:
#                 union_list.extend(o._ngrams)
#             return Ngrams(union_list, exclusive=False)
#
# class Rouge_l:
#     def __init__(self):
#
#     def score(self, evaluated_sentences, reference_sentences, raw_results=False, exclusive=True, **_):
#         if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
#             raise ValueError("Collections must contain at least 1 sentence.")
#
#         # total number of words in reference sentences
#         m = len(
#             Ngrams(
#                 _split_into_words(reference_sentences),
#                 exclusive=exclusive))
#
#         # total number of words in evaluated sentences
#         n = len(
#             Ngrams(
#                 _split_into_words(evaluated_sentences),
#                 exclusive=exclusive))
#
#         # print("m,n %d %d" % (m, n))
#         union_lcs_sum_across_all_references = 0
#         union = Ngrams(exclusive=exclusive)
#         for ref_s in reference_sentences:
#             lcs_count, union = _union_lcs(evaluated_sentences,
#                                           ref_s,
#                                           prev_union=union,
#                                           exclusive=exclusive)
#             union_lcs_sum_across_all_references += lcs_count
#
#         llcs = union_lcs_sum_across_all_references
#         r_lcs = llcs / m
#         p_lcs = llcs / n
#
#         f_lcs = 2.0 * ((p_lcs * r_lcs) / (p_lcs + r_lcs + 1e-8))

if __name__ == '__main__':
    # Demo: ROUGE-L of a candidate that shares the prefix A B C D with the
    # first reference. The second reference scatters A, B, C among other tokens.
    candidate = ["A", "B", "C", "D"] + ["u"] * 6
    ref_contiguous = ["A", "B", "C", "D", "H", "I"] + ["K"] * 6
    ref_interleaved = ["A", "H", "B", "K", "C", "I"] + ["K"] * 6
    scorer = Rouge_l()
    print(scorer.score(candidate, ref_contiguous))