# -*- coding: utf-8 -*-
"""L49 ビルドキャッシュ — 1.25M メッシュの重処理を事前計算

本スクリプトは L49_tsunami_inundation.py から呼ばれる、
津波浸水想定 1,256,706 メッシュの重い空間処理を事前計算するワーカー。

  Step 1: メッシュをサンプリング/集約 → 80,000 セル程度に圧縮
           (10m × 10m を 30m × 30m に集約)
  Step 2: 8 ランクに dissolve した polygon を作成 (L08 と同形式)
  Step 3: 各セルに市町コードを付与 (admin sjoin)
  Step 4: 海岸線距離 (沿岸市町の admin 境界ユニオンへの最短距離) を算定
  Step 5: 河川浸水・高潮との空間関係 (各セルを 3 ハザードのどれが含むか)

出力 (data/extras/L49_tsunami_inundation/_cache/):
  - tsunami_dissolve_8rank.gpkg     (8 polygons, EPSG:6671)
  - tsunami_cells_30m.parquet        (集約済みセル, x, y, depth_max, rank, city_cd, dist_coast_m, hits_storm, hits_river)
  - city_rank_pivot.csv              (市町 × ランク 面積 km²)
  - elevation_proxy.csv              (海岸距離 × 深さ ピボット)
  - hazard_overlap_summary.csv       (3 ハザード重複統計)

実行:
  cd "2026 DoBoX 教材"
  py -X utf8 lessons/_l49_build_cache.py
"""
from __future__ import annotations
import sys, time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from _common import ROOT

import numpy as np
import pandas as pd
import geopandas as gpd
import shapely
from shapely import STRtree
import pyogrio

t_all = time.time()
print("=== L49 build_cache: 津波 1.25M メッシュ事前計算 ===", flush=True)

TARGET_CRS = "EPSG:6671"
DATA_DIR = ROOT / "data" / "extras" / "L49_tsunami_inundation"
CACHE = DATA_DIR / "_cache"
CACHE.mkdir(parents=True, exist_ok=True)

TSUNAMI_SHP = (ROOT / "data" / "extras" / "tsunami_extracted"
               / "340006_tsunami_inundation_assumption_map_20251203" / "浸水メッシュ.shp")
ADMIN_GPKG = ROOT / "data" / "extras" / "L44_storm_surge" / "_cache" / "admin_diss.gpkg"
STORM_MAX_GPKG = ROOT / "data" / "extras" / "L44_storm_surge" / "_cache" / "diss_max.gpkg"
RIVER_MAX_SHP = (ROOT / "data" / "extras" / "flood_shp"
                  / "shinsui_souteisaidai" / "shinsui_souteisaidai.shp")

# 8 inundation-depth ranks (per MLIT / Hiroshima Prefecture guidelines, same format as L08)
# Note: the tsunami assumption's maximum depth is 8.34 m (the Hiroshima coast
# faces the Seto Inland Sea, so depths stay shallower than on the Pacific side)
RANK_BINS  = [0, 0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 999]
RANK_CODES = [10, 20, 30, 40, 50, 60, 70, 80]
RANK_LABEL = {
    10: "0.0〜0.5m", 20: "0.5〜1.0m", 30: "1.0〜2.0m", 40: "2.0〜3.0m",
    50: "3.0〜5.0m", 60: "5.0〜10.0m", 70: "10.0〜20.0m", 80: "20m以上",
}
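# Example: right=False (used below) makes the bins left-closed, so depth=0.5
# falls in [0.5, 1.0) → rank 20, and the dataset maximum of 8.34 m falls in
# [5.0, 10.0) → rank 60; ranks 70 and 80 therefore stay empty for this dataset.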

# coastal municipalities (same definition as L44)
COASTAL_CITIES = {101,102,103,104,107,108, 202,203,204,205,207, 211,213,215, 304,309}

# =============================================================================
# 1. Load the 1.25M-cell mesh + aggregate to 30 m
# =============================================================================
print("\n[1] 1.25M メッシュ読み込み + 30m 集約", flush=True)
t1 = time.time()

df = pyogrio.read_dataframe(TSUNAMI_SHP, read_geometry=False)
df.columns = ["x", "y", "depth"]
print(f"  原メッシュ: {len(df):,} 行  ({time.time()-t1:.1f}s)")

# aggregate onto a 30 m grid (snap x, y down to 30 m multiples, keep the max depth)
# 1,256,706 → roughly 140,000 rows
df["gx"] = (df["x"] // 30) * 30
df["gy"] = (df["y"] // 30) * 30
agg = df.groupby(["gx","gy"], as_index=False)["depth"].max()
agg.columns = ["x","y","depth"]
agg["rank"] = pd.cut(agg["depth"], bins=RANK_BINS, labels=RANK_CODES, right=False).astype(int)
print(f"  30m 集約後: {len(agg):,} 行 ({100*len(agg)/len(df):.1f}%, {time.time()-t1:.1f}s)")
print(f"  rank 別件数:")
for rk in RANK_CODES:
    n = (agg["rank"]==rk).sum()
    if n>0:
        print(f"    rank={rk:>3} ({RANK_LABEL[rk]}): {n:,}")


# =============================================================================
# 2. Dissolve into 8 rank polygons (same format as L08)
# =============================================================================
print("\n[2] 8 ランク polygon (dissolve)", flush=True)
t2 = time.time()

DISS_GPKG = CACHE / "tsunami_dissolve_8rank.gpkg"
if not DISS_GPKG.exists():
    out_geoms, out_ranks, out_areas = [], [], []
    for rk in RANK_CODES:
        sub = agg[agg["rank"] == rk]
        if len(sub) == 0:
            continue
        # 30 m × 30 m squares (cell center ±15 m)
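        # shapely.box is vectorized over NumPy arrays (shapely >= 2.0), so this
        # builds one square Polygon per cell in a single call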
        polys = shapely.box(
            sub["x"].values - 15, sub["y"].values - 15,
            sub["x"].values + 15, sub["y"].values + 15,
        )
        merged = shapely.union_all(polys)  # top-level union_all (shapely >= 2.0)
        out_geoms.append(merged)
        out_ranks.append(rk)
        out_areas.append(merged.area / 1e6)
        print(f"    rank={rk}: {len(sub):,} cells → 1 polygon, {merged.area/1e6:.2f} km²")
    custom_crs = ("+proj=tmerc +lat_0=36 +lon_0=132.166666 +k=0.9999 "
                  "+x_0=0 +y_0=0 +ellps=WGS84 +units=m +no_defs")
    gdf = gpd.GeoDataFrame({"rank": out_ranks, "area_km2": out_areas, "geometry": out_geoms},
                            crs=custom_crs)
    # the custom CRS is effectively identical to EPSG:6671, so just assign it via set_crs
    gdf = gdf.set_crs(TARGET_CRS, allow_override=True)
    gdf.to_file(DISS_GPKG, driver="GPKG")
    print(f"  saved {DISS_GPKG.name} ({DISS_GPKG.stat().st_size/1e6:.1f} MB, {time.time()-t2:.1f}s)")
else:
    print(f"  cached {DISS_GPKG.name} ({time.time()-t2:.1f}s)")


# =============================================================================
# 3. Attach a municipality code to each cell (admin sjoin)
# =============================================================================
print("\n[3] 各セルに市町コードを付与 (admin sjoin)", flush=True)
t3 = time.time()

CELLS_PARQUET = CACHE / "tsunami_cells_30m.parquet"
if not CELLS_PARQUET.exists():
    admin = gpd.read_file(ADMIN_GPKG).to_crs(TARGET_CRS)
    print(f"  admin: {len(admin)} polys")

    # build a point GeoDataFrame (point geometry, EPSG:6671)
    pts = gpd.GeoDataFrame(
        agg.copy(),
        geometry=gpd.points_from_xy(agg["x"], agg["y"]),
        crs=TARGET_CRS,
    )
    print(f"  point gdf: {len(pts):,} rows")

    # sjoin (predicate='within') to attach the municipality code
    pts_admin = gpd.sjoin(pts, admin[["CITY_CD", "geometry"]],
                            how="left", predicate="within")
    pts_admin = pts_admin.drop(columns="index_right")
    print(f"  sjoin done, 市町なしセル: {pts_admin['CITY_CD'].isna().sum():,} "
          f"({time.time()-t3:.1f}s)")

    # keep NaN as -1 (offshore / outside the prefecture)
    pts_admin["CITY_CD"] = pts_admin["CITY_CD"].fillna(-1).astype(int)
    pts_admin.to_parquet(CELLS_PARQUET, index=False)
    print(f"  saved {CELLS_PARQUET.name}")
else:
    print(f"  cached {CELLS_PARQUET.name}")
    pts_admin = gpd.read_parquet(CELLS_PARQUET)


# =============================================================================
# 4. Coastal distance (union of coastal-municipality admin boundaries) — shortest distance per cell (m)
# =============================================================================
print("\n[4] 海岸距離計算 (沿岸市町境界からの距離)", flush=True)
t4 = time.time()

if "dist_coast_m" not in pts_admin.columns:
    admin = gpd.read_file(ADMIN_GPKG).to_crs(TARGET_CRS)
    coastal = admin[admin["CITY_CD"].isin(COASTAL_CITIES)].copy()
    print(f"  沿岸市町数: {len(coastal)} (CITY_CD in {sorted(COASTAL_CITIES)})")

    # extract the boundary (lines) of the coastal-municipality polygons and union them
    boundary_union = coastal.geometry.boundary.union_all()
    # Note: this boundary is the full perimeter of the coastal polygons, i.e. it
    # includes both the seaward and the inland edges. Isolating the true coastline
    # would need extra processing, but tsunami inundation cells sit almost entirely
    # inside coastal municipalities and close to the sea, so treating the perimeter
    # as a "coast candidate line" is a workable approximation; we keep it for simplicity.

    print(f"  boundary union geom_type: {boundary_union.geom_type}, "
          f"length: {boundary_union.length/1000:.1f} km")

    # shortest distance from each cell to the boundary union (accelerated with a shapely STRtree)
    bnd_lines = list(coastal.geometry.boundary.values)
    tree = STRtree(bnd_lines)
    pts_geom = pts_admin.geometry.values

    # nearest-line lookups (batched, with periodic progress output)
    dists = np.zeros(len(pts_geom), dtype=np.float32)
    BATCH = 50000
    for k in range(0, len(pts_geom), BATCH):
        ks = pts_geom[k:k+BATCH]
        idx = tree.nearest(ks)
        for i, ki in enumerate(ks):
            dists[k+i] = ki.distance(bnd_lines[idx[i]])
        if (k//BATCH) % 4 == 0:
            print(f"    {min(k+BATCH, len(pts_geom)):,}/{len(pts_geom):,} done, "
                  f"{time.time()-t4:.1f}s", flush=True)

    pts_admin["dist_coast_m"] = dists
    print(f"  距離 quantile: q10={np.quantile(dists,0.1):.1f}, "
          f"q50={np.quantile(dists,0.5):.1f}, q90={np.quantile(dists,0.9):.1f}, "
          f"max={dists.max():.1f}", flush=True)
    pts_admin.to_parquet(CELLS_PARQUET, index=False)


# =============================================================================
# 5. Overlap with storm-surge max / river max (which hazards contain each cell)
# =============================================================================
print("\n[5] 高潮・河川との重なり", flush=True)
t5 = time.time()

if "hits_storm" not in pts_admin.columns:
    # storm-surge max polygon (already cached by L44)
    storm = gpd.read_file(STORM_MAX_GPKG).to_crs(TARGET_CRS)
    storm_union = storm.geometry.union_all()
    print(f"  storm max union ready ({time.time()-t5:.1f}s)")

    # river flooding, assumed-maximum-scale polygon (same Shapefile as L08)
    river = pyogrio.read_dataframe(RIVER_MAX_SHP)
    if river.crs is None:
        # no CRS in the Shapefile; assume it already matches the target CRS
        # (to_crs would fail without a source CRS)
        river = river.set_crs(TARGET_CRS)
    else:
        river = river.to_crs(TARGET_CRS)
    river["geometry"] = gpd.GeoSeries(shapely.force_2d(river.geometry.values), crs=TARGET_CRS)
    river_union = river.geometry.union_all()
    print(f"  river max union ready ({time.time()-t5:.1f}s)")

    pts_geom = pts_admin.geometry.values
    # batched point-in-polygon flags
    hits_storm = np.zeros(len(pts_geom), dtype=bool)
    hits_river = np.zeros(len(pts_geom), dtype=bool)

    # build one-geometry STRtrees and run batched 'within' queries against them
    tree_storm = STRtree([storm_union])
    tree_river = STRtree([river_union])
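    # STRtree.query(points, predicate=...) returns a (2, n) integer array:
    # row 0 = positions in the input array, row 1 = matched tree geometry
    # indices (always 0 here, since each tree holds a single union geometry)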

    BATCH = 100000
    for k in range(0, len(pts_geom), BATCH):
        ks = pts_geom[k:k+BATCH]
        # which points in this batch fall within each union geometry?
        # query indices are batch-local, so offset them by k before writing;
        # an empty index array simply assigns nothing, so no guard is needed
        idx_s = tree_storm.query(ks, predicate="within")
        hits_storm[k + idx_s[0]] = True
        idx_r = tree_river.query(ks, predicate="within")
        hits_river[k + idx_r[0]] = True
        if (k//BATCH) % 2 == 0:
            print(f"    {min(k+BATCH, len(pts_geom)):,}/{len(pts_geom):,} done, "
                  f"{time.time()-t5:.1f}s", flush=True)

    pts_admin["hits_storm"] = hits_storm
    pts_admin["hits_river"] = hits_river
    print(f"  storm 重なり: {hits_storm.sum():,} cells ({100*hits_storm.mean():.2f}%)")
    print(f"  river 重なり: {hits_river.sum():,} cells ({100*hits_river.mean():.2f}%)")
    pts_admin.to_parquet(CELLS_PARQUET, index=False)


# =============================================================================
# 6. Pivot summaries + save
# =============================================================================
print("\n[6] ピボット集計", flush=True)
t6 = time.time()

# 6a. municipality × rank area (km²)
df = pts_admin.drop(columns="geometry").copy() if "geometry" in pts_admin.columns else pts_admin.copy()
df["area_m2"] = 30 * 30  # 集約後 30m メッシュ面積
city_rank = df.groupby(["CITY_CD","rank"])["area_m2"].sum().unstack(fill_value=0)
city_rank = city_rank / 1e6  # km²
city_rank.columns = [f"rank_{c}_km2" for c in city_rank.columns]
city_rank["total_km2"] = city_rank.sum(axis=1)
city_rank = city_rank.reset_index().sort_values("total_km2", ascending=False)
city_rank.to_csv(CACHE / "city_rank_pivot.csv", index=False, encoding="utf-8-sig")
print(f"  city_rank_pivot: {len(city_rank)} rows")

# 6b. coastal-distance bin × depth pivot (for RQ2)
df["dist_bin"] = pd.cut(df["dist_coast_m"],
                         bins=[0, 50, 100, 200, 500, 1000, 2000, 5000, 1e9],
                         labels=["0-50m","50-100m","100-200m","200-500m",
                                "500-1km","1-2km","2-5km","5km+"],
                         include_lowest=True)  # keep cells at exactly 0 m in the first bin
elevation_proxy = df.groupby(["dist_bin","rank"], observed=True)["area_m2"].sum().unstack(fill_value=0)
elevation_proxy = elevation_proxy / 1e6
elevation_proxy = elevation_proxy.reset_index()
elevation_proxy.to_csv(CACHE / "elevation_proxy.csv", index=False, encoding="utf-8-sig")
print(f"  elevation_proxy: {len(elevation_proxy)} rows")

# 6c. 3-hazard overlap summary (4 patterns: tsunami only, tsunami+storm, tsunami+river, tsunami+storm+river)
df["pattern"] = (
    df["hits_storm"].astype(int) * 2 +
    df["hits_river"].astype(int) * 1
)
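# (a 2-bit flag: bit 1 = storm-surge overlap, bit 0 = river overlap;
#  tsunami membership is implicit, since every cell comes from the tsunami mesh)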
# 0 = tsunami only, 1 = +river, 2 = +storm, 3 = +storm+river
hazard_overlap = df.groupby("pattern")["area_m2"].agg(["sum","count"]).reset_index()
hazard_overlap["sum"] = hazard_overlap["sum"] / 1e6  # km²
hazard_overlap.columns = ["pattern","area_km2","n_cells"]
labels_pattern = {0:"津波のみ", 1:"津波+河川", 2:"津波+高潮", 3:"津波+高潮+河川"}
hazard_overlap["label"] = hazard_overlap["pattern"].map(labels_pattern)
hazard_overlap.to_csv(CACHE / "hazard_overlap_summary.csv", index=False, encoding="utf-8-sig")
print(f"  hazard_overlap saved ({len(hazard_overlap)} rows)")

# 6d. depth rank × storm-surge overlap cross
rank_storm = df.groupby(["rank","hits_storm"])["area_m2"].sum().unstack(fill_value=0)
rank_storm = rank_storm / 1e6
# rename by value rather than by position, in case one of True/False is absent
rank_storm = rank_storm.rename(columns={False: "alone", True: "with_storm"})
rank_storm = rank_storm.reset_index()
rank_storm.to_csv(CACHE / "rank_storm_cross.csv", index=False, encoding="utf-8-sig")

print(f"\n=== build_cache DONE in {time.time()-t_all:.1f}s ===")
