# Import-time trace: prints "Starting <module name> ..." when this module is loaded.
print("Starting", __name__, "...")
import requests, zipfile, os, time, v
import pandas as pd
from datetime import datetime
from google.transit import gtfs_realtime_pb2
from v import *
from data import URLS, ROUTE_TYPE_NAMES
import builtins

# Look-ahead window of departure(), in hours (cutoff = now + HOURS * 3600 s).
HOURS = 2
# Default maximum age of the cached static GTFS zip, in days
# (load_static_gtfs uses AGE * 24 * 3600 seconds as cache_age).
AGE = 30



def load_static_gtfs(country, operator, cache_age=AGE*24*3600):
    """Load the static GTFS feed for ``country``/``operator`` into ``v.cache``.

    The feed zip is cached at ``cache/gtfs_static_cache_{country}_{operator}.zip``
    and downloaded again when it is missing or older than ``cache_age`` seconds.
    URLs come from ``URLS[country][operator]``; when that entry does not exist,
    ``URLS[country]`` is used as a template whose "_SWE_" placeholder is
    replaced by ``operator``.

    Returns (and stores in ``v.cache[country][operator]``) a dict of DataFrames:
    "stops", "trips", "stop_times" (with "departure_secs"/"normalized_secs"
    columns, NaN where departure_time is empty), "routes" (``route_type``
    renamed to the int column "route_type_static") and "calendar_dates" (only
    exception_type "1" rows, "date" as ``datetime.date``).
    """
    try:
        urls = URLS[country][operator]
    except KeyError:
        # Country-level URL template: fill in the operator on a *copy* so the
        # shared URLS entry keeps its "_SWE_" placeholder for other operators.
        urls = dict(URLS[country])
        for key in ("gtfs_url", "trip_updates", "vehicle_positions"):
            urls[key] = urls[key].replace("_SWE_", operator)
    print(urls)
    cache_file = f"cache/gtfs_static_cache_{country}_{operator}.zip"

    need_download = True
    if os.path.exists(cache_file):
        age = time.time() - os.path.getmtime(cache_file)
        if age < cache_age:
            print(f"Using cached GTFS for {country}/{operator} ({age/3600:.1f}h old)")
            need_download = False

    if need_download:
        print(f"Downloading static GTFS for {country}/{operator}…")
        resp = requests.get(urls["gtfs_url"], timeout=60)
        # Fail here rather than caching an HTTP error page as the "zip".
        resp.raise_for_status()
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, "wb") as f:
            f.write(resp.content)
        print("Static GTFS updated")

    # Read every table we need, then close the archive.
    with zipfile.ZipFile(cache_file, "r") as z:
        def read_table(name):
            with z.open(name) as fh:
                return pd.read_csv(fh, dtype=str)

        stops = read_table("stops.txt")
        trips = read_table("trips.txt")
        stop_times = read_table("stop_times.txt")
        routes = read_table("routes.txt")
        calendar_dates = read_table("calendar_dates.txt")

    routes = routes.rename(columns={"route_type": "route_type_static"})
    routes["route_type_static"] = routes["route_type_static"].astype(int)
    print(routes["route_type_static"].value_counts())

    def safe_usecols(df, cols):
        """Keep only the listed columns that actually exist in ``df``."""
        return df[[c for c in cols if c in df.columns]]

    stops = safe_usecols(stops, ["stop_id", "stop_name", "parent_station"])
    trips = safe_usecols(trips, ["trip_id", "route_id", "direction_id", "trip_headsign", "service_id"])
    routes = safe_usecols(routes, ["route_id", "route_short_name", "route_long_name", "route_type_static"])
    stop_times = safe_usecols(stop_times, [
        "trip_id",
        "stop_id",
        "departure_time",
        "stop_headsign",
        "pickup_type",
        "drop_off_type"])

    # calendar_dates is used to filter trips by service date.
    calendar_dates["date"] = pd.to_datetime(calendar_dates["date"], format="%Y%m%d").dt.date
    # Keep only exception_type "1" (service is operating).
    calendar_dates = calendar_dates[calendar_dates["exception_type"] == "1"]

    def time_to_seconds(t):
        """Convert "HH:MM:SS" (HH may exceed 23) to seconds; NaN for empty values."""
        # GTFS leaves departure_time empty for non-timepoint stops.
        if pd.isna(t) or not str(t).strip():
            return float("nan")
        h, m, s = map(int, str(t).split(":"))
        return h * 3600 + m * 60 + s

    def normalize_seconds(sec):
        """Fold one day off times at or past 24:00:00 (GTFS post-midnight times)."""
        if sec >= 24 * 3600:
            return sec - 24 * 3600
        return sec

    stop_times["departure_secs"] = stop_times["departure_time"].apply(time_to_seconds)
    stop_times["normalized_secs"] = stop_times["departure_secs"].apply(normalize_seconds)

    data = {
        "stops": stops,
        "trips": trips,
        "stop_times": stop_times,
        "routes": routes,
        "calendar_dates": calendar_dates,
        }
    try:
        country_cache = v.cache[country]
    except KeyError:
        country_cache = v.cache[country] = {}
    country_cache[operator] = data
    return data

def load_realtime(urls, country, operator):
    """Download and parse GTFS-RT TripUpdates.

    Fetches ``urls["trip_updates"]``, sending ``URLS[country][operator]["header"]``
    as request headers when such an entry exists. Returns a dict mapping
    ``(trip_id, stop_id)`` to the departure delay in seconds. When a
    stop_time_update carries a departure.time but no departure.delay, the value
    is ``True``. That means realtime data exists, but no delay is derived here.
    """
    print("Downloading GTFS-RT…")
    feed = gtfs_realtime_pb2.FeedMessage()
    try:
        header = URLS[country][operator]["header"]
    except (LookupError, TypeError):
        # No per-operator entry (template countries) or no custom headers.
        header = None
    # NOTE(review): verify=False disables TLS certificate verification for this
    # feed (presumably a workaround for an endpoint certificate problem; confirm).
    # The timeout keeps a stalled feed from blocking the caller indefinitely.
    response = requests.get(urls["trip_updates"], verify=False, headers=header, timeout=30)
    feed.ParseFromString(response.content)

    rt_delays = {}
    for entity in feed.entity:
        if not entity.HasField("trip_update"):
            continue
        tu = entity.trip_update
        for stu in tu.stop_time_update:
            if not stu.HasField("departure"):
                continue
            key = (tu.trip.trip_id, stu.stop_id)
            # SNCB often provides departure.time but not delay
            if stu.departure.HasField("delay"):
                rt_delays[key] = stu.departure.delay
            elif stu.departure.HasField("time"):
                rt_delays[key] = True
    return rt_delays

def is_area_id(stop_id: str) -> bool:
    """Return True when ``stop_id`` is 16 decimal digits ending in "000".

    search() uses this for country "se" to prefer such (area) ids when several
    stops share a name.
    """
    if not stop_id.isdigit():
        return False
    if len(stop_id) != 16:
        return False
    return stop_id[-3:] == "000"

def search(query, country, operator):
    """Search the operator's stops by id or name.

    Performs a case-insensitive, literal (non-regex) substring match of
    ``query`` against stop_id and stop_name. The static feed is loaded into
    ``v.cache`` on first use. Returns ``{stop_name: stop_id}``. For "se" and
    "de" only one stop per name is kept; "se" prefers an id accepted by
    is_area_id().
    """
    try:
        country_cache = v.cache[country]
    except KeyError:
        country_cache = v.cache[country] = {}
    if operator not in country_cache:
        # load_static_gtfs stores its tables in v.cache; read them from there.
        load_static_gtfs(country, operator)
    data = v.cache[country][operator]

    stops = data["stops"].copy()
    # fillna before astype(str). The reverse order turns NaN into the text
    # "nan", which would then match queries such as "na".
    stops["stop_id"] = stops["stop_id"].fillna("").astype(str)
    stops["stop_name"] = stops["stop_name"].fillna("").astype(str)

    q = (query or "").lower()
    # regex=False: user input is matched literally. As a regex, "(" raises
    # and "." matches any character.
    mask = (
        stops["stop_id"].str.lower().str.contains(q, na=False, regex=False)
        | stops["stop_name"].str.lower().str.contains(q, na=False, regex=False)
    )
    matches = stops[mask]
    if matches.empty:
        # pd.DataFrame([]) below would have no columns, so stop here.
        return {}

    if country in ("se", "de"):
        best_rows = []
        for _, group in matches.groupby("stop_name"):
            if country == "se":
                area_rows = group[group["stop_id"].apply(is_area_id)]
                if len(area_rows):
                    group = area_rows
            best_rows.append(group.iloc[0])
        matches = pd.DataFrame(best_rows)

    return dict(zip(matches["stop_name"], matches["stop_id"]))


def normalize_stop_id(stop_id, stops_df, country=None):
    """Expand ``stop_id`` into the list of stop ids to query for departures.

    Resolution order (first non-empty result wins):
    1. ``country == "se"`` and a 16-digit id: all ids sharing its first 15 digits.
    2. Stops whose ``parent_station`` equals ``stop_id`` (if that column exists).
    3. ``[stop_id]`` itself, when it is a known stop id.
    4. Stops whose id starts with ``stop_id + "_"``.
    Returns ``[]`` if nothing matches. ``stops_df`` is not modified.
    """
    stop_id = str(stop_id)
    ids = stops_df["stop_id"].astype(str)

    # --- SWEDISH STOP MERGING ---
    if country == "se" and stop_id.isdigit() and len(stop_id) == 16:
        siblings = ids[ids.str.startswith(stop_id[:-1])].tolist()
        if siblings:
            return siblings

    # parent_station is optional in GTFS; without it there are no children.
    if "parent_station" in stops_df.columns:
        is_child = (stops_df["parent_station"].astype(str) == stop_id).to_numpy()
        children = ids[is_child].tolist()
        if children:
            return children

    if ids.eq(stop_id).any():
        return [stop_id]

    return ids[ids.str.startswith(stop_id + "_")].tolist()

def _direction_code(raw):
    """Map a GTFS ``direction_id`` value to the code emitted in departure items.

    0 -> 1 and 1 -> 2 (for any text ``int()`` parses to 0/1). A missing value
    ("" or the "nan" text of a NaN) -> 0. Any other value is returned as its
    string form.
    """
    text = str(raw)
    try:
        number = int(text)
    except ValueError:
        number = None
    if number == 1:
        return 2
    if number == 0:
        return 1
    if not text or text == "nan":
        return 0
    return text


def _boarding_restricted(row):
    """Return True if this stop time's pickup_type + drop_off_type is non-zero.

    GTFS treats an empty pickup_type/drop_off_type as 0 (regular service), so a
    missing column, NaN or blank value counts as 0 instead of raising.
    """
    total = 0
    for column in ("pickup_type", "drop_off_type"):
        value = row.get(column)
        if pd.notna(value) and str(value).strip():
            total += int(value)
    return total != 0


def _delay_seconds(value):
    """Convert a load_realtime() entry into a delay in seconds.

    Two values yield 0:
    - ``None``: there is no realtime entry.
    - The ``True`` marker: load_realtime() stores it when the feed gives only
      an absolute departure time. Adding it directly would shift the
      departure by one second.
    """
    if value is None or isinstance(value, bool):
        return 0
    return value


def departure(stop_id, country, operator):
    """Return the departures from ``stop_id`` within the next ``HOURS`` hours.

    Steps:
    - Load the static feed into ``v.cache`` on first use.
    - Expand ``stop_id`` to platform ids with normalize_stop_id().
    - Keep the trips whose service runs on the matching service day.
    - Apply GTFS-RT departure delays from load_realtime().

    Returns a list of dicts with the keys "destination", "direction_code",
    "expected" (local time, "%Y-%m-%dT%H:%M:%S"), "line" and "deviations",
    sorted by expected time and without duplicates. A non-empty result is also
    stored in ``v.operators[country][operator][stop_id]``.
    """
    try:
        country_cache = v.cache[country]
    except KeyError:
        country_cache = v.cache[country] = {}
    if operator not in country_cache:
        load_static_gtfs(country, operator)

    try:
        urls = URLS[country][operator]
    except KeyError:
        # Country-level URL template: fill in the operator on a *copy* so the
        # shared URLS entry keeps its "_SWE_" placeholder for other operators.
        urls = dict(URLS[country])
        for key in ("gtfs_url", "trip_updates", "vehicle_positions"):
            urls[key] = urls[key].replace("_SWE_", operator)

    data = v.cache[country][operator]
    stops = data["stops"]
    trips = data["trips"]
    stop_times = data["stop_times"]
    routes = data["routes"]
    calendar_dates = data["calendar_dates"]

    # Normalize stop_id to platform IDs
    platform_ids = normalize_stop_id(stop_id, stops)
    if not platform_ids:
        print("No matching stop_ids found for:", stop_id)
        return []

    now = pd.Timestamp.now()
    today = now.normalize()
    now_secs = now.hour * 3600 + now.minute * 60 + now.second
    window = HOURS * 3600
    day_secs = 24 * 3600

    at_stop = stop_times[stop_times["stop_id"].astype(str).isin(platform_ids)].merge(trips, on="trip_id")

    # GTFS departure_time counts from the start of the trip's *service day*
    # and may exceed 24:00:00. A departure inside [now, now + HOURS] can
    # therefore belong to yesterday's service (post-midnight runs), to today's,
    # or, when the window crosses midnight, to tomorrow's. Check each candidate
    # day against its own calendar_dates entry and its own time window.
    # (Like the rest of this module, this ignores DST shifts within a day.)
    per_day = []
    for offset in range(-1, int((now_secs + window) // day_secs) + 1):
        service_day = today + pd.Timedelta(days=offset)
        active = calendar_dates.loc[calendar_dates["date"] == service_day.date(), "service_id"]
        lower = now_secs - offset * day_secs
        in_window = at_stop["departure_secs"].between(lower, lower + window)
        rows = at_stop[at_stop["service_id"].isin(active) & in_window]
        if not rows.empty:
            per_day.append(rows.assign(service_midnight=service_day))
    if not per_day:
        return []

    future = (
        pd.concat(per_day, ignore_index=True)
        .merge(routes, on="route_id")
        .merge(stops, on="stop_id"))
    if future.empty:
        return []

    future["traffic_type"] = future["route_type_static"].apply(
        lambda x: ROUTE_TYPE_NAMES.get(x, x))

    rt_delays = load_realtime(urls, country, operator)
    keys = list(zip(future["trip_id"], future["stop_id"]))
    # Keep all departures; mark realtime if available
    future["has_realtime"] = [key in rt_delays for key in keys]
    delays = pd.Series([_delay_seconds(rt_delays.get(key)) for key in keys], index=future.index)
    future["rt_departure_secs"] = future["departure_secs"] + delays

    # Absolute expected time = service-day midnight + (delayed) GTFS seconds.
    future["expected_at"] = future["service_midnight"] + pd.to_timedelta(future["rt_departure_secs"], unit="s")
    future["real_time"] = future["expected_at"].dt.strftime("%Y-%m-%dT%H:%M:%S")
    future = future.sort_values("expected_at", kind="stable")
    print("FUTURE: ", future)

    def get_destination(row):
        """Pick the display destination for a departure row.

        Checks stop_headsign, then trip_headsign, then route_long_name, then
        the name of the trip's last stop in stop_times. Returns "" if none
        of these is available.
        """
        for column in ("stop_headsign", "trip_headsign", "route_long_name"):
            value = row.get(column)
            if pd.notna(value) and value.strip():
                return value

        # Fallback: last stop of the trip, in file order.
        trip_rows = stop_times[stop_times["trip_id"] == row["trip_id"]]
        if len(trip_rows):
            last_stop_id = trip_rows.iloc[-1]["stop_id"]
            name = stops.loc[stops["stop_id"] == last_stop_id, "stop_name"]
            if len(name):
                return name.iloc[0]
        return ""

    output = []
    for _, row in future.iterrows():
        if operator == "sncb" and _boarding_restricted(row):
            continue
        item = {
            "destination": get_destination(row).split("(")[0],
            "direction_code": _direction_code(row.get("direction_id", "")),
            "expected": row["real_time"],
            "line": {
                "id": str(row["route_short_name"]),
                "transport_mode": row["traffic_type"],
            },
            "deviations": [],
        }
        if item not in output:
            output.append(item)
    if output:
        v.operators[country][operator][stop_id] = output
    return output
