// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU General Public License version 2.
pub mod stats;
use stats::Metrics;

use std::mem::MaybeUninit;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;

use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use clap::Parser;
use crossbeam::channel::RecvTimeoutError;
use libbpf_rs::skel::Skel;
use libbpf_rs::AsRawLibbpf;
use libbpf_rs::MapCore as _;
use libbpf_rs::OpenObject;
use libbpf_rs::ProgramInput;
use scx_arena::ArenaLib;
use scx_stats::prelude::*;
use scx_utils::build_id;
use scx_utils::compat;
use scx_utils::init_libbpf_logging;
use scx_utils::libbpf_clap_opts::LibbpfOpts;
use scx_utils::pm::{
    cpu_idle_resume_latency_supported, epp_supported, for_each_uncore_domain, get_epp,
    get_turbo_enabled, get_uncore_max_freq_khz, get_uncore_min_freq_khz, set_epp,
    set_turbo_enabled, set_uncore_max_freq_khz, turbo_supported, uncore_freq_supported,
    update_cpu_idle_resume_latency,
};
use scx_utils::scx_ops_attach;
use scx_utils::scx_ops_load;
use scx_utils::scx_ops_open;
use scx_utils::uei_exited;
use scx_utils::uei_report;
use scx_utils::Topology;
use scx_utils::UserExitInfo;
use scx_utils::NR_CPU_IDS;
use tracing::{debug, info, warn};
use tracing_subscriber::filter::EnvFilter;

use bpf_intf::stat_idx_P2DQ_NR_STATS;
use bpf_intf::stat_idx_P2DQ_STAT_ATQ_ENQ;
use bpf_intf::stat_idx_P2DQ_STAT_ATQ_REENQ;
use bpf_intf::stat_idx_P2DQ_STAT_DIRECT;
use bpf_intf::stat_idx_P2DQ_STAT_DISPATCH_PICK2;
use bpf_intf::stat_idx_P2DQ_STAT_DSQ_CHANGE;
use bpf_intf::stat_idx_P2DQ_STAT_DSQ_SAME;
use bpf_intf::stat_idx_P2DQ_STAT_EAS_BIG_SELECT;
use bpf_intf::stat_idx_P2DQ_STAT_EAS_FALLBACK;
use bpf_intf::stat_idx_P2DQ_STAT_EAS_LITTLE_SELECT;
use bpf_intf::stat_idx_P2DQ_STAT_ENQ_CPU;
use bpf_intf::stat_idx_P2DQ_STAT_ENQ_INTR;
use bpf_intf::stat_idx_P2DQ_STAT_ENQ_LLC;
use bpf_intf::stat_idx_P2DQ_STAT_ENQ_MIG;
use bpf_intf::stat_idx_P2DQ_STAT_EXEC_BALANCE;
use bpf_intf::stat_idx_P2DQ_STAT_EXEC_SAME_LLC;
use bpf_intf::stat_idx_P2DQ_STAT_FORK_BALANCE;
use bpf_intf::stat_idx_P2DQ_STAT_FORK_SAME_LLC;
use bpf_intf::stat_idx_P2DQ_STAT_IDLE;
use bpf_intf::stat_idx_P2DQ_STAT_KEEP;
use bpf_intf::stat_idx_P2DQ_STAT_LLC_MIGRATION;
use bpf_intf::stat_idx_P2DQ_STAT_NODE_MIGRATION;
use bpf_intf::stat_idx_P2DQ_STAT_SELECT_PICK2;
use bpf_intf::stat_idx_P2DQ_STAT_THERMAL_AVOID;
use bpf_intf::stat_idx_P2DQ_STAT_THERMAL_KICK;
use bpf_intf::stat_idx_P2DQ_STAT_WAKE_LLC;
use bpf_intf::stat_idx_P2DQ_STAT_WAKE_MIG;
use bpf_intf::stat_idx_P2DQ_STAT_WAKE_PREV;
use scx_p2dq::bpf_intf;
use scx_p2dq::bpf_skel::*;
use scx_p2dq::SchedulerOpts;
use scx_p2dq::TOPO;

/// Scheduler name interpolated into the "Unregister ..." log line emitted when a
/// `Scheduler` is dropped.
const SCHEDULER_NAME: &str = "scx_p2dq";
// NOTE: because this struct derives `clap::Parser`, the `///` comments on the struct and
// its fields are rendered by clap as the command's help text. Editing them changes
// user-visible `--help` output; maintainer-only notes therefore use plain `//` comments.
/// scx_p2dq: A pick 2 dumb queuing load balancing scheduler.
///
/// The BPF part does simple vtime or round robin scheduling in each domain
/// while tracking average load of each domain and duty cycle of each task.
///
#[derive(Debug, Parser)]
struct CliOpts {
    // Only inspected in `main` to emit a deprecation warning; it has no effect on logging.
    /// Deprecated, noop, use RUST_LOG or --log-level instead.
    #[clap(short = 'v', long, action = clap::ArgAction::Count)]
    verbose: u8,

    // `main` tries the default environment filter first and only falls back to this value;
    // `Scheduler::init` also scans this string for "trace"/"debug" to pick the BPF debug level.
    /// Specify the logging level. Accepts rust's envfilter syntax for modular
    /// logging: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax. Examples: ["info", "warn,tokio=info"]
    #[clap(long, default_value = "info")]
    pub log_level: String,

    // Interval is passed to `Duration::from_secs_f64` in `main`, i.e. it is in seconds.
    // If `--monitor` is also given, `main` uses the `--monitor` interval (`monitor.or(stats)`).
    /// Enable stats monitoring with the specified interval.
    #[clap(long)]
    pub stats: Option<f64>,

    // When set, `main` joins the monitor thread and returns without creating a scheduler.
    /// Run in stats monitoring mode with the specified interval. Scheduler
    /// is not launched.
    #[clap(long)]
    pub monitor: Option<f64>,

    /// Print version and exit.
    #[clap(long)]
    pub version: bool,

    // Only logged once at startup by `main`.
    /// Optional run ID for tracking scheduler instances.
    #[clap(long)]
    pub run_id: Option<u64>,

    // Scheduler tunables shared with the `scx_p2dq` library crate, flattened into this CLI.
    #[clap(flatten)]
    pub sched: SchedulerOpts,

    // libbpf-related options; converted into BPF open options in `Scheduler::init`.
    #[clap(flatten, next_help_heading = "Libbpf Options")]
    pub libbpf: LibbpfOpts,
}

/// Userspace handle for one loaded instance of the scx_p2dq BPF scheduler.
///
/// Built by `Scheduler::init`, attached by `start`, and serviced by `run`.
struct Scheduler<'a> {
    /// Loaded BPF skeleton; borrows the caller-provided `OpenObject` storage for `'a`.
    skel: BpfSkel<'a>,
    /// Link returned by `scx_ops_attach!`. `None` after `init`, `Some` after a successful
    /// `start`, and taken (dropped) again at the end of `run` or in `Drop`.
    struct_ops: Option<libbpf_rs::Link>,
    /// Debug level derived in `init` from the log level string: 2 if it contains "trace",
    /// 1 if it contains "debug", otherwise 0. `start` prints the topology when it is > 0.
    debug_level: u8,

    /// Stats server launched in `init`; `run` answers its requests with `get_metrics`.
    stats_server: StatsServer<(), Metrics>,
}

impl<'a> Scheduler<'a> {
    /// Opens, configures and loads the BPF scheduler, then launches the stats server.
    ///
    /// The returned scheduler is loaded but not attached (`struct_ops` is `None`); call
    /// `start` to attach it.
    ///
    /// # Arguments
    /// * `opts` - scheduler options. When `hw_auto_optimize` is set, a clone is adjusted by
    ///   the detected `HardwareProfile` and that clone is what configures the skeleton.
    /// * `libbpf_ops` - cloned and converted into the BPF open options.
    /// * `open_object` - storage handed to `scx_ops_open!`; the skeleton borrows it for `'a`.
    /// * `log_level` - only scanned for the substrings "trace" and "debug" to derive the
    ///   debug level (2, 1, or 0 otherwise).
    ///
    /// # Errors
    /// Propagates failures from topology discovery, opening the BPF object,
    /// `init_open_skel!`, loading the skeleton, and launching the stats server.
    fn init(
        opts: &SchedulerOpts,
        libbpf_ops: &LibbpfOpts,
        open_object: &'a mut MaybeUninit<OpenObject>,
        log_level: &str,
    ) -> Result<Self> {
        // Derive the debug level from the log level string by substring match, so a
        // filter such as "warn,foo=debug" also yields level 1.
        let debug_level = if log_level.contains("trace") {
            2
        } else if log_level.contains("debug") {
            1
        } else {
            0
        };
        let mut skel_builder = BpfSkelBuilder::default();
        // Object-builder debug output is only requested at the highest ("trace") level.
        skel_builder.obj_builder.debug(debug_level > 1);
        init_libbpf_logging(None);
        info!(
            "Running scx_p2dq (build ID: {})",
            build_id::full_version(env!("CARGO_PKG_VERSION"))
        );
        // Virtual LLC support needs the topology built from the user-supplied topology args.
        let topo = if opts.virt_llc_enabled {
            Topology::with_args(&opts.topo)?
        } else {
            Topology::new()?
        };
        let open_opts = libbpf_ops.clone().into_bpf_open_opts();
        let mut open_skel = scx_ops_open!(skel_builder, open_object, p2dq, open_opts).context(
            "Failed to open BPF object. This can be caused by a mismatch between the kernel \
            version and the BPF object, permission or other libbpf issues. Try running `dmesg \
            | grep bpf` to see if there are any error messages related to the BPF object. See \
            the LibbpfOptions section in the help for more information on configuration related \
            to this issue or file an issue on the scx repo if the problem persists. \
            https://github.com/sched-ext/scx/issues/new?labels=scx_p2dq&title=scx_p2dq:%20New%20Issue&assignees=hodgesds&body=Kernel%20version:%20(fill%20me%20out)%0ADistribution:%20(fill%20me%20out)%0AHardware:%20(fill%20me%20out)%0A%0AIssue:%20(fill%20me%20out)"
        )?;

        // Disable autoload for thermal pressure tracepoint by default
        // Will be conditionally enabled if kernel supports it
        // Note: This tracepoint only exists on ARM/ARM64 architectures
        #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
        open_skel.progs.on_thermal_pressure.set_autoload(false);

        // Apply hardware-specific optimizations before macro
        let hw_profile = scx_p2dq::HardwareProfile::detect();
        // Work on a copy so the caller's options stay untouched.
        let mut opts_optimized = opts.clone();
        if opts.hw_auto_optimize {
            hw_profile.optimize_scheduler_opts(&mut opts_optimized);
        }

        scx_p2dq::init_open_skel!(
            &mut open_skel,
            topo,
            &opts_optimized,
            debug_level,
            &hw_profile
        )?;

        // Thermal pressure tracking (ARM/ARM64 only)
        #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
        {
            // Support is detected by the presence of the tracepoint under either the
            // tracefs mount or the legacy debugfs tracing path.
            let thermal_enabled = std::path::Path::new(
                "/sys/kernel/tracing/events/thermal_pressure/hw_pressure_update",
            )
            .exists()
                || std::path::Path::new(
                    "/sys/kernel/debug/tracing/events/thermal_pressure/hw_pressure_update",
                )
                .exists();

            if thermal_enabled {
                debug!(
                    "Kernel supports thermal pressure tracking, enabling hw_pressure_update tracepoint"
                );
                open_skel.progs.on_thermal_pressure.set_autoload(true);
                stats::set_thermal_tracking_enabled(true);

                // Mirror the decision into the BPF read-only config before load.
                open_skel
                    .maps
                    .rodata_data
                    .as_mut()
                    .unwrap()
                    .p2dq_config
                    .thermal_enabled = std::mem::MaybeUninit::new(true);
            } else {
                debug!("Kernel does not support thermal pressure tracking (CONFIG_SCHED_HW_PRESSURE not enabled)");
            }
        }

        // Tell the stats module which optional feature groups are active.
        if opts_optimized.enable_eas {
            stats::set_eas_enabled(true);
        }

        // ATQ stats are only flagged as enabled when the `bpf_spin_unlock` symbol exists;
        // a failed symbol lookup is treated the same as "not present".
        if opts_optimized.atq_enabled && compat::ksym_exists("bpf_spin_unlock").unwrap_or(false) {
            stats::set_atq_enabled(true);
        }

        // NOTE(review): this reads `opts`, not `opts_optimized`, unlike the two checks
        // above. Confirm whether hardware auto-optimization is meant to be able to change
        // `queued_wakeup`.
        if opts.queued_wakeup {
            open_skel.struct_ops.p2dq_mut().flags |= *compat::SCX_OPS_ALLOW_QUEUED_WAKEUP;
        }
        open_skel.struct_ops.p2dq_mut().flags |= *compat::SCX_OPS_KEEP_BUILTIN_IDLE;

        // Disable autoattach for the struct_ops map since we attach it manually via
        // attach_struct_ops() in scx_ops_attach!(). This prevents libbpf from warning
        // about uninitialized skeleton link during attach().
        //
        // SAFETY: the raw map pointer is obtained from `open_skel.maps.p2dq`, which is alive
        // for the duration of this call, and the pointer is not retained afterwards.
        unsafe {
            libbpf_rs::libbpf_sys::bpf_map__set_autoattach(
                open_skel.maps.p2dq.as_libbpf_object().as_ptr(),
                false,
            );
        }

        let mut skel = scx_ops_load!(open_skel, p2dq, uei)?;
        scx_p2dq::init_skel!(&mut skel, topo);

        let stats_server = StatsServer::new(stats::server_data()).launch()?;

        // Attaching is deferred to `start`, hence no struct_ops link yet.
        Ok(Self {
            skel,
            struct_ops: None,
            debug_level,
            stats_server,
        })
    }

    fn get_metrics(&self) -> Metrics {
        let mut stats = vec![0u64; stat_idx_P2DQ_NR_STATS as usize];
        let stats_map = &self.skel.maps.stats;
        for stat in 0..stat_idx_P2DQ_NR_STATS {
            let cpu_stat_vec: Vec<Vec<u8>> = stats_map
                .lookup_percpu(&stat.to_ne_bytes(), libbpf_rs::MapFlags::ANY)
                .unwrap()
                .unwrap();
            let sum: u64 = cpu_stat_vec
                .iter()
                .map(|val| u64::from_ne_bytes(val.as_slice().try_into().unwrap()))
                .sum();
            stats[stat as usize] = sum;
        }
        Metrics {
            atq_enq: stats[stat_idx_P2DQ_STAT_ATQ_ENQ as usize],
            atq_reenq: stats[stat_idx_P2DQ_STAT_ATQ_REENQ as usize],
            direct: stats[stat_idx_P2DQ_STAT_DIRECT as usize],
            idle: stats[stat_idx_P2DQ_STAT_IDLE as usize],
            dsq_change: stats[stat_idx_P2DQ_STAT_DSQ_CHANGE as usize],
            same_dsq: stats[stat_idx_P2DQ_STAT_DSQ_SAME as usize],
            keep: stats[stat_idx_P2DQ_STAT_KEEP as usize],
            enq_cpu: stats[stat_idx_P2DQ_STAT_ENQ_CPU as usize],
            enq_intr: stats[stat_idx_P2DQ_STAT_ENQ_INTR as usize],
            enq_llc: stats[stat_idx_P2DQ_STAT_ENQ_LLC as usize],
            enq_mig: stats[stat_idx_P2DQ_STAT_ENQ_MIG as usize],
            select_pick2: stats[stat_idx_P2DQ_STAT_SELECT_PICK2 as usize],
            dispatch_pick2: stats[stat_idx_P2DQ_STAT_DISPATCH_PICK2 as usize],
            llc_migrations: stats[stat_idx_P2DQ_STAT_LLC_MIGRATION as usize],
            node_migrations: stats[stat_idx_P2DQ_STAT_NODE_MIGRATION as usize],
            wake_prev: stats[stat_idx_P2DQ_STAT_WAKE_PREV as usize],
            wake_llc: stats[stat_idx_P2DQ_STAT_WAKE_LLC as usize],
            wake_mig: stats[stat_idx_P2DQ_STAT_WAKE_MIG as usize],
            fork_balance: stats[stat_idx_P2DQ_STAT_FORK_BALANCE as usize],
            exec_balance: stats[stat_idx_P2DQ_STAT_EXEC_BALANCE as usize],
            fork_same_llc: stats[stat_idx_P2DQ_STAT_FORK_SAME_LLC as usize],
            exec_same_llc: stats[stat_idx_P2DQ_STAT_EXEC_SAME_LLC as usize],
            thermal_kick: stats[stat_idx_P2DQ_STAT_THERMAL_KICK as usize],
            thermal_avoid: stats[stat_idx_P2DQ_STAT_THERMAL_AVOID as usize],
            eas_little_select: stats[stat_idx_P2DQ_STAT_EAS_LITTLE_SELECT as usize],
            eas_big_select: stats[stat_idx_P2DQ_STAT_EAS_BIG_SELECT as usize],
            eas_fallback: stats[stat_idx_P2DQ_STAT_EAS_FALLBACK as usize],
        }
    }

    /// Serves stats requests until shutdown is requested or the BPF side reports an exit.
    ///
    /// Each pass waits up to one second for a stats request; a request is answered with a
    /// fresh `get_metrics` snapshot and a timeout simply re-checks the exit conditions.
    /// Once the loop ends, the struct_ops link is dropped and the user exit info is
    /// reported.
    ///
    /// # Errors
    /// Fails if sending a stats response fails, if receiving fails for any reason other
    /// than a timeout, or if `uei_report!` returns an error.
    fn run(&mut self, shutdown: Arc<AtomicBool>) -> Result<UserExitInfo> {
        let (res_ch, req_ch) = self.stats_server.channels();
        let poll_interval = Duration::from_secs(1);

        loop {
            // The shutdown flag is checked first, so the exit info is not polled once
            // shutdown has been requested.
            if shutdown.load(Ordering::Relaxed) || uei_exited!(&self.skel, uei) {
                break;
            }

            match req_ch.recv_timeout(poll_interval) {
                Ok(()) => res_ch.send(self.get_metrics())?,
                Err(RecvTimeoutError::Timeout) => {}
                Err(other) => return Err(other.into()),
            }
        }

        // Drop the link (if any) before reporting the exit info.
        drop(self.struct_ops.take());
        uei_report!(&self.skel, uei)
    }

    fn print_topology(&mut self) -> Result<()> {
        let input = ProgramInput {
            ..Default::default()
        };

        let output = self.skel.progs.arena_topology_print.test_run(input)?;
        if output.return_value != 0 {
            bail!(
                "Could not initialize arenas, topo_print returned {}",
                output.return_value as i32
            );
        }

        Ok(())
    }

    /// Attaches the scheduler's struct_ops and keeps the resulting link in `self`.
    ///
    /// When the debug level is non-zero the topology is printed right after attaching.
    ///
    /// # Errors
    /// Fails if attaching fails (no link is stored in that case) or if printing the
    /// topology fails.
    fn start(&mut self) -> Result<()> {
        let link = scx_ops_attach!(self.skel, p2dq)?;
        self.struct_ops = Some(link);

        if self.debug_level != 0 {
            self.print_topology()?;
        }

        info!("P2DQ scheduler started! Run `scx_p2dq --monitor` for metrics.");

        Ok(())
    }
}

impl Drop for Scheduler<'_> {
    /// Logs the unregistration and drops the struct_ops link if one is still held.
    fn drop(&mut self) {
        info!("Unregister {SCHEDULER_NAME} scheduler");

        // `take()` leaves `None` behind; dropping the returned `Option` drops the link,
        // or does nothing if `run` already took it.
        drop(self.struct_ops.take());
    }
}

#[clap_main::clap_main]
fn main(opts: CliOpts) -> Result<()> {
    if opts.version {
        println!(
            "scx_p2dq: {}",
            build_id::full_version(env!("CARGO_PKG_VERSION"))
        );
        return Ok(());
    }

    let env_filter = EnvFilter::try_from_default_env()
        .or_else(|_| match EnvFilter::try_new(&opts.log_level) {
            Ok(filter) => Ok(filter),
            Err(e) => {
                eprintln!(
                    "invalid log envvar: {}, using info, err is: {}",
                    opts.log_level, e
                );
                EnvFilter::try_new("info")
            }
        })
        .unwrap_or_else(|_| EnvFilter::new("info"));

    match tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_target(true)
        .with_thread_ids(true)
        .with_file(true)
        .with_line_number(true)
        .try_init()
    {
        Ok(()) => {}
        Err(e) => eprintln!("failed to init logger: {}", e),
    }

    if opts.verbose > 0 {
        warn!("Setting verbose via -v is deprecated and will be an error in future releases.");
    }

    if let Some(run_id) = opts.run_id {
        info!("scx_p2dq run_id: {}", run_id);
    }

    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();
    ctrlc::set_handler(move || {
        shutdown_clone.store(true, Ordering::Relaxed);
    })
    .context("Error setting Ctrl-C handler")?;

    if let Some(intv) = opts.monitor.or(opts.stats) {
        let shutdown_copy = shutdown.clone();
        let jh = std::thread::spawn(move || {
            match stats::monitor(Duration::from_secs_f64(intv), shutdown_copy) {
                Ok(_) => {
                    debug!("stats monitor thread finished successfully")
                }
                Err(error_object) => {
                    warn!("stats monitor thread finished because of an error {error_object}")
                }
            }
        });
        if opts.monitor.is_some() {
            let _ = jh.join();
            return Ok(());
        }
    }

    if let Some(idle_resume_us) = opts.sched.idle_resume_us {
        if !cpu_idle_resume_latency_supported() {
            warn!("idle resume latency not supported");
        } else if idle_resume_us > 0 {
            info!("Setting idle QoS to {idle_resume_us}us");
            for cpu in TOPO.all_cpus.values() {
                update_cpu_idle_resume_latency(cpu.id, idle_resume_us.try_into().unwrap())?;
            }
        }
    }

    let is_efficiency = opts.sched.sched_mode == scx_p2dq::SchedMode::Efficiency;
    let is_performance = opts.sched.sched_mode == scx_p2dq::SchedMode::Performance;

    let mut orig_uncore_freqs: Vec<(u32, u32, u32)> = Vec::new();
    if opts.sched.uncore_max_freq_mhz.is_some() || is_efficiency || is_performance {
        if !uncore_freq_supported() {
            if opts.sched.uncore_max_freq_mhz.is_some() {
                warn!("uncore frequency control not supported");
            }
        } else {
            let _ = for_each_uncore_domain(|pkg, die| {
                let freq_khz = if let Some(mhz) = opts.sched.uncore_max_freq_mhz {
                    mhz * 1000
                } else if is_efficiency {
                    get_uncore_min_freq_khz(pkg, die)?
                } else {
                    get_uncore_max_freq_khz(pkg, die)?
                };
                if let Ok(orig) = get_uncore_max_freq_khz(pkg, die) {
                    if orig != freq_khz {
                        info!(
                            "Setting max uncore frequency for package {} die {} to {} MHz",
                            pkg,
                            die,
                            freq_khz / 1000
                        );
                        orig_uncore_freqs.push((pkg, die, orig));
                        set_uncore_max_freq_khz(pkg, die, freq_khz)?;
                    }
                }
                Ok(())
            });
        }
    }

    let mut orig_epps: Vec<(usize, String)> = Vec::new();
    if (is_efficiency || is_performance) && epp_supported() {
        let target_epp = if is_efficiency {
            "power"
        } else {
            "performance"
        };
        for cpu in TOPO.all_cpus.values() {
            if let Ok(orig) = get_epp(cpu.id) {
                if orig != target_epp {
                    if orig_epps.is_empty() {
                        info!("Setting EPP to {} for all CPUs", target_epp);
                    }
                    orig_epps.push((cpu.id, orig));
                    let _ = set_epp(cpu.id, target_epp);
                }
            }
        }
    }

    let orig_turbo = if turbo_supported() {
        let target_turbo = opts.sched.turbo.or(if is_efficiency {
            Some(false)
        } else if is_performance {
            Some(true)
        } else {
            None
        });
        if let Some(want_enabled) = target_turbo {
            if let Ok(current) = get_turbo_enabled() {
                if current != want_enabled {
                    let mode_suffix = if opts.sched.turbo.is_none() {
                        if is_efficiency {
                            " for efficiency mode"
                        } else {
                            " for performance mode"
                        }
                    } else {
                        ""
                    };
                    info!(
                        "{} turbo{}",
                        if want_enabled {
                            "Enabling"
                        } else {
                            "Disabling"
                        },
                        mode_suffix
                    );
                    let _ = set_turbo_enabled(want_enabled);
                    Some(current)
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        }
    } else {
        if opts.sched.turbo.is_some() {
            warn!("turbo control not supported");
        }
        None
    };

    let mut open_object = MaybeUninit::uninit();
    loop {
        let mut sched =
            Scheduler::init(&opts.sched, &opts.libbpf, &mut open_object, &opts.log_level)?;
        let task_size = std::mem::size_of::<types::task_p2dq>();
        let arenalib = ArenaLib::init(sched.skel.object_mut(), task_size, *NR_CPU_IDS)?;
        arenalib.setup()?;

        sched.start()?;

        if !sched.run(shutdown.clone())?.should_restart() {
            break;
        }
    }

    if let Some(was_enabled) = orig_turbo {
        info!(
            "Restoring turbo to {}",
            if was_enabled { "enabled" } else { "disabled" }
        );
        let _ = set_turbo_enabled(was_enabled);
    }

    if !orig_epps.is_empty() {
        info!("Restoring EPP settings");
        for (cpu, epp) in orig_epps {
            let _ = set_epp(cpu, &epp);
        }
    }

    for (pkg, die, orig_khz) in orig_uncore_freqs {
        info!(
            "Restoring uncore frequency for package {} die {} to {} MHz",
            pkg,
            die,
            orig_khz / 1000
        );
        let _ = set_uncore_max_freq_khz(pkg, die, orig_khz);
    }

    Ok(())
}
