# -*- coding: utf-8 -*-
"""L63 build_cache — 津波災害警戒区域 1.28M ポリゴンの重処理を事前計算

本スクリプトは L63_tsunami_warning_zone.py から呼ばれる、
津波災害警戒区域の 1,285,428 polygon の重い空間処理を事前計算するワーカー。

  Step 1: 元 polygon (= 1.28M セル) を 30m グリッドに集約
           polygonの代表点 (重心または bbox 中心) → 30m 単位丸め → max 基準水位
  Step 2: 8 ランクに dissolve した polygon 作成
  Step 3: 各セルに市町コードを付与 (admin sjoin, L44/L49 既キャッシュ admin 流用)
  Step 4: L49 津波浸水想定 cells (既キャッシュ) との空間関係 (重なり)
  Step 5: 集計ピボット作成

  注: kijyun_sin は cm 単位の基準水位 (基準水位 = 浸水深を cm で記録、
       value/10 = m)。範囲 10-628 = 1.0〜62.8m だが、実際は 1.0〜6.28m が現実的
       なので /100 する解釈もあり得る。確認のため統計を出す。

出力 (data/extras/L63_tsunami_warning_zone/_cache/):
  - keikai_dissolve_8rank.gpkg     (8 polygons, EPSG:6671)
  - keikai_cells_30m.parquet        (集約済みセル: x, y, kijyun_m, rank, city_cd, in_l49)
  - city_rank_pivot.csv              (市町 × ランク 面積 km²)
  - l49_overlap_summary.csv          (L49 想定と本警戒区域の重なり統計)
  - rank_l49_cross.csv               (深さランク × L49 重複)
"""
from __future__ import annotations
import sys, time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from _common import ROOT

import numpy as np
import pandas as pd
import geopandas as gpd
import shapely
import pyogrio
import pyarrow.parquet as pq

t_all = time.time()
print("=== L63 build_cache: 津波警戒区域 1.28M ポリゴン事前計算 ===", flush=True)

TARGET_CRS = "EPSG:6671"
DATA_DIR = ROOT / "data" / "extras" / "L63_tsunami_warning_zone"
CACHE = DATA_DIR / "_cache"
CACHE.mkdir(parents=True, exist_ok=True)

KEIKAI_SHP = DATA_DIR / "extracted" / "tsunami_keikai.shp"
ADMIN_GPKG = ROOT / "data" / "extras" / "L44_storm_surge" / "_cache" / "admin_diss.gpkg"

# L49 assumed tsunami inundation zone caches (already built)
L49_CELLS = ROOT / "data" / "extras" / "L49_tsunami_inundation" / "_cache" / "tsunami_cells_30m.parquet"
L49_DISS  = ROOT / "data" / "extras" / "L49_tsunami_inundation" / "_cache" / "tsunami_dissolve_8rank.gpkg"
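# (only L49_CELLS is read below; L49_DISS is listed for reference)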

# 8 inundation-depth ranks (same scheme as L49)
RANK_BINS  = [0, 0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 999]
RANK_CODES = [10, 20, 30, 40, 50, 60, 70, 80]
RANK_LABEL = {
    10: "0.0-0.5m", 20: "0.5-1.0m", 30: "1.0-2.0m", 40: "2.0-3.0m",
    50: "3.0-5.0m", 60: "5.0-10.0m", 70: "10.0-20.0m", 80: "20m+",
}
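
# Worked example of the binning applied in step 1 (pd.cut with right=False
# gives half-open [lo, hi) intervals): 0.49 m → rank 10, exactly 0.5 m →
# rank 20, 6.28 m → rank 60. With the /100 scale adopted below, depths top
# out at 6.28 m, so ranks 70/80 stay empty but are kept for L49 schema parity.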

# Interpretation of kijyun_sin:
#   observed min=10, max=628 → the choice is between /100 and /10.
#   /100 reading: 0.10-6.28 m → plausible for a Seto Inland Sea tsunami
#   /10  reading: 1.0-62.8 m  → implausibly large (exceeds Pacific-coast cases)
#   → adopt /100.
KIJYUN_SCALE = 0.01
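
# Sanity guard (a minimal check, not part of the original pipeline): under
# /100 the observed extremes map to 10 → 0.10 m and 628 → 6.28 m; a value
# outside (0, 10) m would signal that the scale assumption is wrong.
assert 0.0 < 628 * KIJYUN_SCALE < 10.0, "unexpected kijyun_sin scale"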


# =============================================================================
# 1. Read source shapefile + extract centroids + aggregate to 30 m
# =============================================================================
print("\n[1] Read shapefile + aggregate to 30 m", flush=True)
t1 = time.time()

CELLS_PARQUET = CACHE / "keikai_cells_30m.parquet"

if not CELLS_PARQUET.exists() or "in_l49" not in pq.read_schema(CELLS_PARQUET).names:
    # read every record (heavy: 174 MB shp, 1.28M polygons)
    print("  read_dataframe (1.28M polygons)...", flush=True)
    gdf = pyogrio.read_dataframe(KEIKAI_SHP)
    print(f"  read done: {len(gdf):,} rows  ({time.time()-t1:.1f}s)", flush=True)
    print(f"  CRS: {gdf.crs}")
    print(f"  kijyun_sin describe: min={gdf['kijyun_sin'].min()}, "
          f"median={gdf['kijyun_sin'].median()}, max={gdf['kijyun_sin'].max()}", flush=True)

    # extract centroids and truncate x,y to integer metres (the 30 m snap happens below)
    cx = gdf.geometry.centroid.x.values
    cy = gdf.geometry.centroid.y.values
    gdf_xy = pd.DataFrame({
        "x": cx.astype(np.int64),
        "y": cy.astype(np.int64),
        "kijyun_sin": gdf["kijyun_sin"].astype(np.int64).values,
    })
    del gdf

    # inundation depth in metres
    gdf_xy["depth_m"] = gdf_xy["kijyun_sin"] * KIJYUN_SCALE

    # aggregate to 30 m cells (max depth per cell)
    gdf_xy["gx"] = (gdf_xy["x"] // 30) * 30
    gdf_xy["gy"] = (gdf_xy["y"] // 30) * 30
    agg = gdf_xy.groupby(["gx", "gy"], as_index=False)["depth_m"].max()
    agg.columns = ["x", "y", "depth_m"]
    agg["rank"] = pd.cut(agg["depth_m"], bins=RANK_BINS,
                          labels=RANK_CODES, right=False).astype(int)

    print(f"  30m 集約後: {len(agg):,} rows ({100*len(agg)/1285428:.1f}%, "
          f"{time.time()-t1:.1f}s)", flush=True)
    print(f"  rank 別件数:")
    for rk in RANK_CODES:
        n = (agg["rank"] == rk).sum()
        if n > 0:
            print(f"    rank={rk:>3} ({RANK_LABEL[rk]}): {n:,}")

    # CRS: the source file is EPSG:2445 (JGD2000 / Japan Plane Rectangular CS III);
    # reproject to EPSG:6671 (JGD2011 / CS III) to match the L49 cache. The
    # JGD2000→JGD2011 shift is tiny here, but to_crs applies it properly.
    pts = gpd.GeoDataFrame(
        agg.copy(),
        geometry=gpd.points_from_xy(agg["x"], agg["y"]),
        crs="EPSG:2445",
    ).to_crs(TARGET_CRS)

    # overwrite x, y with the reprojected EPSG:6671 coordinates
    pts["x"] = pts.geometry.x.astype(np.int64)
    pts["y"] = pts.geometry.y.astype(np.int64)

    # attach municipality codes via admin sjoin
    print("\n[1b] municipality sjoin", flush=True)
    admin = gpd.read_file(ADMIN_GPKG).to_crs(TARGET_CRS)
    print(f"  admin: {len(admin)} polys")
    pts_admin = gpd.sjoin(pts, admin[["CITY_CD", "geometry"]],
                          how="left", predicate="within")
    pts_admin = pts_admin.drop(columns="index_right")
    n_outside = pts_admin["CITY_CD"].isna().sum()
    print(f"  cells with no municipality (offshore): {n_outside:,}")
    pts_admin["CITY_CD"] = pts_admin["CITY_CD"].fillna(-1).astype(int)

    # flag overlap with the L49 assumed inundation cells
    print("\n[1c] overlap with the L49 assumed inundation zone (in_l49 flag)", flush=True)
    if L49_CELLS.exists():
        l49_cells = gpd.read_parquet(L49_CELLS)
        # The L49 cell x,y are in the same EPSG:6671 frame, so overlap can be
        # tested at cell level: snap both datasets' (x, y) keys to the shared
        # 30 m grid and compare.
        l49_cells["gx"] = (l49_cells["x"] // 30) * 30
        l49_cells["gy"] = (l49_cells["y"] // 30) * 30

        pts_admin["gx"] = (pts_admin["x"] // 30) * 30
        pts_admin["gy"] = (pts_admin["y"] // 30) * 30
        # vectorised membership test on the (gx, gy) keys
        in_l49 = pd.MultiIndex.from_frame(pts_admin[["gx", "gy"]]).isin(
            pd.MultiIndex.from_frame(l49_cells[["gx", "gy"]]))
        pts_admin["in_l49"] = in_l49
        n_in = int(in_l49.sum())
        n_out = int((~in_l49).sum())
        print(f"  L49 assumed cells: {len(l49_cells):,}, L63 warning cells: {len(pts_admin):,}")
        print(f"  L63 ∩ L49 (in both assumption and warning zone): {n_in:,} ({100*n_in/len(pts_admin):.1f}%)")
        print(f"  L63 \\ L49 (warning zone only): {n_out:,} ({100*n_out/len(pts_admin):.1f}%)")
    else:
        print(f"  WARN: L49 cells キャッシュなし → in_l49 は全 False")
        pts_admin["in_l49"] = False

    pts_admin.to_parquet(CELLS_PARQUET, index=False)
    print(f"  saved {CELLS_PARQUET.name} ({CELLS_PARQUET.stat().st_size/1e6:.1f} MB)")
else:
    print(f"  cached {CELLS_PARQUET.name}")
    pts_admin = gpd.read_parquet(CELLS_PARQUET)
print(f"  ({time.time()-t1:.1f}s)", flush=True)


# =============================================================================
# 2. Polygons dissolved into 8 ranks
# =============================================================================
print("\n[2] dissolve into 8 rank polygons", flush=True)
t2 = time.time()

DISS_GPKG = CACHE / "keikai_dissolve_8rank.gpkg"
if not DISS_GPKG.exists():
    out_geoms, out_ranks, out_areas = [], [], []
    for rk in RANK_CODES:
        sub = pts_admin[pts_admin["rank"] == rk]
        if len(sub) == 0:
            continue
        polys = shapely.box(
            sub["x"].values - 15, sub["y"].values - 15,
            sub["x"].values + 15, sub["y"].values + 15,
        )
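        # The 30 m boxes are centred on the snapped coordinate, i.e. shifted a
        # uniform half cell (15 m) from the source grid corners; this leaves
        # dissolve topology and area totals unchanged.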
        merged = shapely.union_all(polys)
        out_geoms.append(merged)
        out_ranks.append(rk)
        out_areas.append(merged.area / 1e6)
        print(f"    rank={rk}: {len(sub):,} cells → 1 polygon, "
              f"{merged.area/1e6:.2f} km²")
    gdf = gpd.GeoDataFrame({"rank": out_ranks, "area_km2": out_areas,
                              "geometry": out_geoms},
                             crs=TARGET_CRS)
    gdf.to_file(DISS_GPKG, driver="GPKG")
    print(f"  saved {DISS_GPKG.name} ({DISS_GPKG.stat().st_size/1e6:.1f} MB)")
else:
    print(f"  cached {DISS_GPKG.name}")
print(f"  ({time.time()-t2:.1f}s)", flush=True)


# =============================================================================
# 3. Pivot summaries
# =============================================================================
print("\n[3] pivot summaries", flush=True)
t3 = time.time()

df = pts_admin.drop(columns="geometry").copy() if "geometry" in pts_admin.columns else pts_admin.copy()
df["area_m2"] = 30 * 30  # 30m × 30m = 900 m²

# 3a. municipality × rank area (km²)
city_rank = df.groupby(["CITY_CD", "rank"])["area_m2"].sum().unstack(fill_value=0)
city_rank = city_rank / 1e6
city_rank.columns = [f"rank_{c}_km2" for c in city_rank.columns]
city_rank["total_km2"] = city_rank.sum(axis=1)
city_rank = city_rank.reset_index().sort_values("total_km2", ascending=False)
city_rank.to_csv(CACHE / "city_rank_pivot.csv", index=False, encoding="utf-8-sig")
print(f"  city_rank_pivot: {len(city_rank)} rows")

# 3b. L49 overlap summary
l49_overlap = df.groupby("in_l49")["area_m2"].agg(["sum", "count"]).reset_index()
l49_overlap["sum"] = l49_overlap["sum"] / 1e6
l49_overlap.columns = ["in_l49", "area_km2", "n_cells"]
l49_overlap["label"] = l49_overlap["in_l49"].map({True: "assumed ∩ warning (both)",
                                                  False: "warning only (outside assumption)"})
l49_overlap.to_csv(CACHE / "l49_overlap_summary.csv", index=False, encoding="utf-8-sig")
print(f"  l49_overlap rows: {len(l49_overlap)}")

# 3c. rank × L49 overlap cross-tab
rank_l49 = df.groupby(["rank", "in_l49"])["area_m2"].sum().unstack(fill_value=0)
rank_l49 = rank_l49 / 1e6
# rename whichever of the True/False columns exists (robust to column order)
rank_l49 = rank_l49.rename(columns={False: "only_keikai", True: "both"})
rank_l49 = rank_l49.reset_index()
rank_l49.to_csv(CACHE / "rank_l49_cross.csv", index=False, encoding="utf-8-sig")
print(f"  rank_l49_cross: {len(rank_l49)} rows")

# 3d. municipality × L49 overlap
city_l49 = df.groupby(["CITY_CD", "in_l49"])["area_m2"].sum().unstack(fill_value=0)
city_l49 = city_l49 / 1e6
# rename by value, not identity: `c is False` fails on numpy bool columns
city_l49 = city_l49.rename(columns={False: "only_keikai_km2", True: "both_km2"})
city_l49["total_km2"] = city_l49.sum(axis=1)
city_l49 = city_l49.reset_index().sort_values("total_km2", ascending=False)
city_l49.to_csv(CACHE / "city_l49_cross.csv", index=False, encoding="utf-8-sig")
print(f"  city_l49_cross: {len(city_l49)} rows")

print(f"  ({time.time()-t3:.1f}s)", flush=True)
print(f"\n=== build_cache DONE in {time.time()-t_all:.1f}s ===")
