# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/pdf.ipynb.

# %% auto 0
__all__ = ['PdfClient']

# %% ../nbs/pdf.ipynb 3
from reportlab.pdfgen import canvas
from reportlab.pdfbase import pdfmetrics
from reportlab.lib.pagesizes import A4, landscape
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
import hashlib
import os
import json
import requests
from PIL import Image
from reportlab.lib.colors import *
from tqdm import tqdm
from reportlab.lib import colors
import random
from bs4 import BeautifulSoup

# %% ../nbs/pdf.ipynb 4
class PdfClient:
    """Client that builds PDF files from IIIF manifests."""

    # Root directory for cached manifests and per-task working files.
    tmp_dir = "tmp"
    # When true, sort_annotation dumps its sorted rows as JSON under tmp_dir.
    debug = False

    def __init__(self):
        """
        Constructor. Creates the temporary working directory if it is missing.
        """
        working_root = self.tmp_dir
        os.makedirs(working_root, exist_ok=True)

    def get_manifest_json_from_path(self, path):
        """
        Load JSON data from a file path.

        Parameters
        ----------
        path : str
            Path of the JSON file.

        Returns
        -------
        dict
            The data parsed from the JSON file.
        """
        with open(path) as manifest_file:
            return json.load(manifest_file)

    def get_manifest_json_from_url(self, iiif_manifest_url):
        """
        Fetch JSON data from a URL, reusing the on-disk copy when one exists.

        The response is cached at ``<tmp_dir>/manifest/<md5 of URL>.json``;
        the MD5 digest is only used as a cache file name.

        Parameters
        ----------
        iiif_manifest_url : str
            URL of the JSON data.

        Returns
        -------
        dict
            The JSON data obtained from the URL (or from the cache).

        Raises
        ------
        requests.HTTPError
            If the server answers with an error status. Nothing is cached
            in that case.
        """
        url_hash = hashlib.md5(iiif_manifest_url.encode()).hexdigest()
        path = f"{self.tmp_dir}/manifest/{url_hash}.json"

        if not os.path.exists(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            response = requests.get(iiif_manifest_url)
            # Fail loudly on HTTP errors so that an error payload is never
            # stored and then served as the manifest on every later call.
            response.raise_for_status()
            manifest_json = response.json()
            # The context manager guarantees the cache file is flushed and
            # closed before it is read back below.
            with open(path, "w") as cache_file:
                json.dump(manifest_json, cache_file, ensure_ascii=False, indent=4)

        return self.get_manifest_json_from_path(path)

    def getHiImagePath(self, image_url):
        """
        Download the TIFF behind a clioimg.hi.u-tokyo.ac.jp viewer API URL.

        The API URL (path segments separated by ``%2F``, ``.jpg`` suffix) is
        rewritten to the ``/viewer/image/idata/...tif`` URL, and the file is
        cached at ``<task_tmp_dir>/images/<md5 of rewritten URL>.tif``.

        Parameters
        ----------
        image_url : str
            Viewer API image URL.

        Returns
        -------
        str
            Local path of the downloaded (or already cached) TIFF.

        Raises
        ------
        requests.HTTPError
            If the server answers with an error status. Nothing is cached
            in that case.
        """
        api_prefix = "https://clioimg.hi.u-tokyo.ac.jp/viewer/api/image/idata"
        encoded_path = image_url.replace(api_prefix, "").replace(".jpg", "")
        segments = [segment for segment in encoded_path.split("%2F") if segment != ""]

        # e.g. .../viewer/image/idata/850/8500/58/0201/0001.tif
        image_url = "/".join(["https://clioimg.hi.u-tokyo.ac.jp/viewer/image/idata", *segments]) + ".tif"

        image_hash = hashlib.md5(image_url.encode()).hexdigest()
        img_path = f"{self.task_tmp_dir}/images/{image_hash}.tif"

        if not os.path.exists(img_path):
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            response = requests.get(image_url)
            # Never cache an HTTP error page as if it were the image.
            response.raise_for_status()
            with open(img_path, "wb") as img_file:
                img_file.write(response.content)

        return img_path

    def set_image(self, newPdfPage, im, img_path, page_size):
        """
        Draw the page image onto the current PDF page.

        When ``self.is_compressed`` is true the image is first re-encoded as
        a JPEG next to the source file (``<img_path>_resized.jpg``) using
        ``self.compress_quality``; otherwise the source file is drawn as is.

        Parameters
        ----------
        newPdfPage : reportlab.pdfgen.canvas.Canvas
            Canvas to draw on.
        im : PIL.Image.Image
            The already opened image for ``img_path``.
        img_path : str
            Path of the source image file.
        page_size : dict
            Page size in points, with ``'width'`` and ``'height'`` keys.
        """
        is_compressed = self.is_compressed
        compress_quality = self.compress_quality

        if is_compressed:
            im_jpg_path = img_path + "_resized.jpg"
            # The JPEG writer only accepts these modes; images with alpha,
            # a palette, etc. would otherwise make save() raise.
            if im.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
                im = im.convert("RGB")
            im.save(im_jpg_path, optimize=True, quality=compress_quality)
        else:
            im_jpg_path = img_path

        newPdfPage.drawImage(
            im_jpg_path,
            0,
            0,
            width=page_size['width'],
            height=page_size['height'],
            preserveAspectRatio=True,
        )

    def get_scale(self, page_size, image_width, image_height):
        """
        Return the factor that fits the image inside the page.

        This is the smaller of the height ratio and the width ratio, so the
        scaled image never exceeds the page in either direction.
        """
        return min(
            page_size['height'] / image_height,
            page_size['width'] / image_width,
        )


    def get_canvases(self):
        """
        Load the configured manifest and return its canvases.

        The manifest comes from ``self.iiif_manifest_url`` when set,
        otherwise from ``self.iiif_manifest_path``. Only IIIF Presentation
        API 3 manifests are accepted. When ``self.canvas_range`` is set, the
        slice ``[start:end]`` of the canvas list is returned.

        Returns
        -------
        list
            The ``items`` of the manifest, optionally sliced.

        Raises
        ------
        Exception
            If neither manifest source is configured, or the manifest does
            not declare the Presentation 3 context.
        """
        canvas_range = self.canvas_range
        iiif_manifest_url = self.iiif_manifest_url
        iiif_manifest_path = self.iiif_manifest_path

        if iiif_manifest_url is not None:
            manifest_json = self.get_manifest_json_from_url(iiif_manifest_url)
        elif iiif_manifest_path is not None:
            manifest_json = self.get_manifest_json_from_path(iiif_manifest_path)
        else:
            # Without this branch manifest_json would be unbound below.
            raise Exception('iiif_manifest_url or iiif_manifest_path must be specified.')

        contexts = manifest_json.get("@context", [])
        if not isinstance(contexts, list):
            contexts = [contexts]
        # When several contexts are listed, the Presentation context is the
        # LAST entry, so test membership rather than the first element.
        if "http://iiif.io/api/presentation/3/context.json" not in contexts:
            raise Exception("Not supported context")

        canvases = manifest_json["items"]

        if canvas_range is not None:
            canvases = canvases[canvas_range[0]:canvas_range[1]]

        return canvases

    def download_image(self, image_url, i):
        """
        Download an image and return its local path.

        When ``self.image_download_dir`` is None the file is stored at
        ``<task_tmp_dir>/images/<md5 of URL>.jpg``; otherwise it is stored in
        that directory as a zero-padded, 1-based index (``0001.jpg``, ...).
        An existing file is reused without contacting the server.

        Parameters
        ----------
        image_url : str
            URL of the image to download.
        i : int
            Zero-based index of the image.

        Returns
        -------
        str
            Path of the downloaded image.

        Raises
        ------
        requests.HTTPError
            If the server answers with an error status. Nothing is written
            in that case.
        """
        if self.image_download_dir is None:
            image_hash = hashlib.md5(image_url.encode()).hexdigest()
            img_path = f"{self.task_tmp_dir}/images/{image_hash}.jpg"
        else:
            img_path = f"{self.image_download_dir}/{str(i+1).zfill(4)}.jpg"

        if not os.path.exists(img_path):
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            response = requests.get(image_url)
            # An error page saved as .jpg would be reused on every later run
            # and fail when opened as an image, so refuse to cache it.
            response.raise_for_status()
            with open(img_path, "wb") as img_file:
                img_file.write(response.content)

        return img_path

    def get_img_path(self, iiif_canvas, i):
        """
        Return the local path of a canvas's image, downloading it if needed.

        URLs served by the clioimg.hi.u-tokyo.ac.jp viewer API go through
        ``getHiImagePath``; every other URL goes through ``download_image``.

        Parameters
        ----------
        iiif_canvas : dict
            A canvas of the IIIF manifest.
        i : int
            Index of the canvas.

        Returns
        -------
        str
            Path of the image.
        """
        painting_annotation = iiif_canvas["items"][0]["items"][0]
        image_url = painting_annotation["body"]["id"]

        hi_api_marker = "https://clioimg.hi.u-tokyo.ac.jp/viewer/api/image/idata"
        if hi_api_marker not in image_url:
            return self.download_image(image_url, i)
        return self.getHiImagePath(image_url)

    def convert_iiif2pdf(self, output_path, iiif_manifest_url = None, iiif_manifest_path = None, post_text_size = 0, default_color = "red", default_alpha=0.0, default_main_color = "gray", default_main_alpha = 0.0, canvas_range=None, font_page_limit = 24, task_id = "base", compress_quality = 100, image_download_dir = None, is_copressed = False, dictionary = None, with_ocr = True):
        """
        Convert a IIIF manifest into a PDF.

        Each canvas becomes one A4 page (landscape when the image is wider
        than tall) holding the canvas image and, optionally, its OCR text.

        Parameters
        ----------
        output_path : str
            Path of the PDF to write.
        iiif_manifest_url : str, optional
            URL of the IIIF manifest.
        iiif_manifest_path : str, optional
            File path of the IIIF manifest.
        post_text_size : int, optional
            Number of characters of the next line in the same group that are
            appended to each OCR line.
        default_color : str, optional
            Default text color.
        default_alpha : float, optional
            Default text opacity.
        default_main_color : str, optional
            Default color of main-text lines.
        default_main_alpha : float, optional
            Default opacity of main-text lines.
        canvas_range : list, optional
            Range of canvases to convert, as [start, end].
        font_page_limit : int, optional
            Upper bound for the font size on the page.
        task_id : str, optional
            Task ID; working files go to ``<tmp_dir>/<task_id>``.
        compress_quality : int, optional
            JPEG quality used when images are compressed.
        image_download_dir : str, optional
            Directory to download images into.
        is_copressed : bool, optional
            Whether to compress images. (The spelling is kept because it is
            part of the public signature.)
        dictionary : dict, optional
            Variant-character dictionary applied to the OCR text, character
            by character. ``None`` (the default) means no substitution.
        with_ocr : bool, optional
            Whether to add the OCR text layer.

        Raises
        ------
        Exception
            If neither ``iiif_manifest_url`` nor ``iiif_manifest_path`` is
            given.
        """
        # Validate before touching any state or creating the output canvas.
        if iiif_manifest_url is None and iiif_manifest_path is None:
            raise Exception('iiif_manifest_url or iiif_manifest_path must be specified.')

        # A None default avoids sharing one mutable dict between calls.
        if dictionary is None:
            dictionary = {}

        self.task_tmp_dir = f"{self.tmp_dir}/{task_id}"

        self.isItaiji = bool(dictionary)
        self.dictionary = dictionary
        self.font_page_limit = font_page_limit

        self.post_text_size = post_text_size
        self.default_color = default_color
        self.default_alpha = default_alpha
        self.default_main_color = default_main_color
        self.default_main_alpha = default_main_alpha

        self.is_compressed = is_copressed
        self.compress_quality = compress_quality

        self.canvas_range = canvas_range
        self.iiif_manifest_url = iiif_manifest_url
        self.iiif_manifest_path = iiif_manifest_path

        self.image_download_dir = image_download_dir

        newPdfPage = canvas.Canvas(output_path)

        # Vertical-writing CID font used for the OCR text layer.
        pdfmetrics.registerFont(UnicodeCIDFont('HeiseiKakuGo-W5', isVertical=True))

        canvases = self.get_canvases()

        for i, iiif_canvas in enumerate(tqdm(canvases)):
            img_path = self.get_img_path(iiif_canvas, i)

            # Close every image once its page is drawn; otherwise one file
            # handle per page stays open for the whole conversion.
            with Image.open(img_path) as im:
                image_width, image_height = im.size

                # Landscape page for images that are wider than tall.
                if image_width / image_height > 1.0:
                    page_dimensions = landscape(A4)
                else:
                    page_dimensions = A4
                newPdfPage.setPageSize(page_dimensions)

                page_size = {}
                page_size['width'], page_size['height'] = page_dimensions

                # Keep the smaller of the two ratios so the image fits the page.
                scale = self.get_scale(page_size, image_width, image_height)

                self.set_image(newPdfPage, im, img_path, page_size)

            if with_ocr:
                self.appendOcr(newPdfPage, scale, page_size, iiif_canvas, image_width, image_height)

            newPdfPage.showPage()

        newPdfPage.save()

    def appendOcr(self, newPdfPage, scale, page_size, iiif_canvas, image_width, image_height):
        dictionary = self.dictionary
        font_page_limit = self.font_page_limit

        post_text_size = self.post_text_size
        default_main_alpha = self.default_main_alpha
        default_color = self.default_color
        default_alpha = self.default_alpha
        default_main_color = self.default_main_color

        offset_page_x = (page_size['width'] - image_width * scale) / 2
        offset_page_y = (page_size['height'] - image_height * scale) / 2

        annotations = iiif_canvas["annotations"][0]["items"]

        sorted_annotations = self.sort_annotation(annotations)

        prev_group = None

        for i in range(len(sorted_annotations)):
            row = sorted_annotations[i]
            x1 = row["x1"]
            y1 = row["y1"]
            x2 = row["x2"]
            
            text_value = row["text"]
            if self.isItaiji:
                ts = list(text_value)
                for j in range(len(ts)):
                    if ts[j] in dictionary:
                        ts[j] = dictionary[ts[j]]
                text_value = "".join(ts)


            w = row["w"]
            h = row["h"]

            if len(text_value) == 0:
                continue

            anchor_y = image_height - y1
            anchor_x = x1

            text_height = h / len(text_value)
            font_image_size = text_height

            font_page_size = font_image_size * scale

            if font_page_size > font_page_limit:
                # continue
                font_page_size = font_page_limit

            newPdfPage.setFont('HeiseiKakuGo-W5', font_page_size)

            preText = ""

            postText = self.getPostText(i, sorted_annotations, size=post_text_size)

            start = (anchor_y + text_height * len(preText))

            fixed_text_value = preText + text_value + postText

            color, alpha = self.get_color(prev_group, row, default_color=default_color, default_alpha=default_alpha, default_main_color=default_main_color, default_main_alpha=default_main_alpha)

            newPdfPage.setFillColor(color, alpha=alpha)

            # newPdfPage.drawString(scale * anchor_x, scale * start, fixed_text_value)

            x_start = anchor_x + font_image_size / 2
            x_start = (x1 + x2) / 2

            x_page_start = offset_page_x + scale * x_start
            y_page_start = offset_page_y + scale * start

            newPdfPage.drawString(x_page_start, y_page_start, fixed_text_value)

            prev_group = row["group"]

            # break


    def get_color(self, prev_group, row, default_color = "red", default_alpha = 0.5, default_main_color = "gray", default_main_alpha = 0.5):
        """
        Choose the fill color and opacity for one OCR row.

        A random reportlab color name is drawn whenever the row starts a new
        group. Rows of type "本文" always get the main defaults (the draw
        still happens for them, but its result is discarded); all other rows
        get the drawn or default color together with ``default_alpha``.

        Returns
        -------
        tuple
            ``(color, alpha)``.
        """
        starts_new_group = prev_group != row["group"]

        picked_color = default_color
        if starts_new_group:
            # Every module-level attribute of reportlab's colors module that
            # is a Color instance counts as a candidate name.
            palette = [name for name in dir(colors) if isinstance(getattr(colors, name), colors.Color)]
            picked_color = random.choice(palette)

        if row["type"] == "本文":
            return default_main_color, default_main_alpha

        return picked_color, default_alpha

    def sort_annotation(self, annotations):
        """
        Parse, group and order the text annotations of one canvas.

        Each annotation's ``target`` must carry an ``#xywh=x,y,w,h``
        fragment and its ``body["value"]`` holds the text. A leading
        ``【type】`` prefix selects the text type; everything else (including
        text whose ``【`` is never closed) is main text ("本文"). Only the
        first ``】`` terminates the prefix, so later ones stay in the text.

        Rows are processed per type, in order of first appearance, after
        sorting them from right to left (descending ``x1``):

        * "本文": rows are merged into columns by horizontal overlap and
          emitted top to bottom within each column; all get group "本文".
        * "頭注": rows are split into runs wherever the horizontal gap is at
          least half the average width; group is ``頭注-<run>``.
        * "割注": as "頭注", then each run is split again by vertical
          overlap (``divideByY``); group is ``割注-<run>-<part>``.
        * any other type: every row is its own group ``<type>-<n>``.

        Parameters
        ----------
        annotations : list
            Annotation dicts of one annotation page.

        Returns
        -------
        list
            Row dicts with keys ``x1, y1, x2, y2, text, type, w, h, group``
            in drawing order.
        """

        def split_by_x_gap(objects):
            """Split right-to-left sorted rows into runs separated by a wide gap."""
            widths = [(obj["x2"] - obj["x1"]) for obj in objects]
            average_width = sum(widths) / len(widths)

            runs = [[objects[0]]]
            for obj in objects[1:]:
                # A gap below half the average width continues the current run.
                if runs[-1][-1]["x1"] - obj["x2"] < average_width / 2:
                    runs[-1].append(obj)
                else:
                    runs.append([obj])
            return runs

        # text type -> rows, in order of first appearance of each type
        rows_by_type = {}

        for annotation in annotations:
            xywh = annotation["target"].split("#xywh=")[1].split(",")
            x1 = int(xywh[0])
            y1 = int(xywh[1])
            w = int(xywh[2])
            h = int(xywh[3])

            text = annotation["body"]["value"]

            # partition() splits on the first closing bracket only, so text
            # that itself contains "】" no longer breaks the unpacking.
            if text.startswith("【") and "】" in text:
                text_type, _, text_value = text[1:].partition("】")
            else:
                text_type = "本文"
                text_value = text

            row = {
                "x1": x1,
                "y1": y1,
                "x2": x1 + w,
                "y2": y1 + h,
                "text": text_value,
                "type": text_type,
                "w": w,
                "h": h,
            }

            rows_by_type.setdefault(text_type, []).append(row)

        rows2 = []

        for text_type, unsorted_rows in rows_by_type.items():
            # Right to left. The sort is stable, so rows sharing an x1 keep
            # their input order; the numeric key has no width limit.
            rows = sorted(unsorted_rows, key=lambda r: -r["x1"])

            if text_type == "本文":
                groups = []
                for obj in rows:
                    for group in groups:
                        # Overlaps the horizontal extent spanned by the group.
                        if obj["x1"] <= group[-1]["x2"] and obj["x2"] >= group[0]["x1"]:
                            group.append(obj)
                            group.sort(key=lambda x: x['x1'])
                            break
                    else:
                        groups.append([obj])

                for group in groups:
                    # Top to bottom; ties keep the group's x1 order.
                    for obj in sorted(group, key=lambda o: o["y1"]):
                        obj["group"] = f"{text_type}"
                        rows2.append(obj)

            elif text_type == "頭注":
                for i, group in enumerate(split_by_x_gap(rows)):
                    for obj in group:
                        obj["group"] = f"{text_type}-{i + 1}"
                        rows2.append(obj)

            elif text_type == "割注":
                for i, group in enumerate(split_by_x_gap(rows)):
                    # Within a run, separate the rows that do not overlap
                    # vertically.
                    for j, group2 in enumerate(self.divideByY(group)):
                        for obj in group2:
                            obj["group"] = f"{text_type}-{i + 1}-{j + 1}"
                            rows2.append(obj)

            else:
                for i, obj in enumerate(rows):
                    obj["group"] = f"{text_type}-{i+1}"
                    rows2.append(obj)

        if self.debug:
            with open(f"{self.tmp_dir}/sorted_annotaions.json", "w") as f:
                json.dump(rows2, f, indent=4, ensure_ascii=False)

        return rows2


    def getPostText(self, i, rows, size=5):
        """
        Return the first ``size`` characters of the next row in the same group.

        Rows after index ``i`` are scanned in order; an empty string is
        returned when no later row shares the group of ``rows[i]``.
        """
        group_name = rows[i]["group"]
        following_heads = (
            candidate["text"][:size]
            for candidate in rows[i + 1:]
            if candidate["group"] == group_name
        )
        return next(following_heads, "")
    
    def divideByY(self, objects):
        """
        Partition rows into groups of vertically overlapping boxes.

        Each row joins the first existing group whose vertical extent
        (``y1`` of its first member to ``y2`` of its last member) it
        overlaps; that group is then re-sorted by descending ``x1``. A row
        that overlaps no group starts a new one.

        Returns
        -------
        list
            List of groups, each a list of the input row dicts.
        """
        partitions = []

        for candidate in objects:
            for members in partitions:
                overlaps = candidate["y1"] <= members[-1]["y2"] and candidate["y2"] >= members[0]["y1"]
                if overlaps:
                    members.append(candidate)
                    # Keep each group ordered from right to left.
                    members.sort(key=lambda member: -member['x1'])
                    break
            else:
                # No existing group overlaps this row.
                partitions.append([candidate])

        return partitions

    @staticmethod
    def getHiItaiji(): # self
        """
        Build a character mapping from the itaiji list page of
        wwwap.hi.u-tokyo.ac.jp.

        Every table row with exactly three cells contributes entries: each
        token of the third cell (split on U+3000, skipping a lone U+00A0) is
        mapped to the text of the second cell.
        NOTE(review): presumably the third cell lists variant forms and the
        second the standard form - confirm against the page.

        Returns
        -------
        dict
            Mapping from third-cell token to second-cell text.
        """
        url = "https://wwwap.hi.u-tokyo.ac.jp/ships/itaiji_list.jsp"
        page = BeautifulSoup(requests.get(url).content, "html.parser")

        itaijiMap = {}
        for table_row in page.find_all("tr"):
            cells = table_row.find_all("td")
            if len(cells) == 3:
                w_new = cells[1].text
                itaijiMap.update(
                    (w_old, w_new)
                    for w_old in cells[2].text.split("\u3000")
                    if w_old != "\xa0"
                )

        return itaijiMap

