use std::{io::{Write}, fs::File};

extern crate rand;
use rand::*;

fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    let mut file = std::fs::File::create("src/magic/tables.rs").unwrap();
    create_tables(&mut file);
}




/// A set of board squares, one bit per square.
type Bitboard = u64;
/// A square index `0..64`, i.e. the bit position inside a [`Bitboard`].
type Square = u8;

// Files: FILE_H is the least-significant bit of every byte, and each file toward A is
// one bit more significant, so FILE_A ends up as the most-significant bit of every byte.
pub const FILE_H: Bitboard = 0x_0101_0101_0101_0101;
pub const FILE_G: Bitboard = FILE_H << 1;
pub const FILE_F: Bitboard = FILE_H << 2;
pub const FILE_E: Bitboard = FILE_H << 3;
pub const FILE_D: Bitboard = FILE_H << 4;
pub const FILE_C: Bitboard = FILE_H << 5;
pub const FILE_B: Bitboard = FILE_H << 6;
pub const FILE_A: Bitboard = FILE_H << 7;

// Rows: ROW_1 is the least-significant byte; each following row is the next byte up.
pub const ROW_1: Bitboard = 0x_0000_0000_0000_00FF;
pub const ROW_2: Bitboard = ROW_1 << 8;
pub const ROW_3: Bitboard = ROW_1 << 16;
pub const ROW_4: Bitboard = ROW_1 << 24;
pub const ROW_5: Bitboard = ROW_1 << 32;
pub const ROW_6: Bitboard = ROW_1 << 40;
pub const ROW_7: Bitboard = ROW_1 << 48;
pub const ROW_8: Bitboard = ROW_1 << 56;

/// Renders `b` as an 8x8 text grid, `x ` for a set square and `. ` for an empty one.
///
/// Row 0 of the output is the most-significant byte (ROW_8). Within a row the columns run
/// from the byte's most-significant bit (file A) to its least-significant bit (file H).
/// Every row ends with a newline.
pub fn print_board(b: Bitboard) -> String {
    // 8 rows of (8 two-character cells + '\n').
    let mut rendered = String::with_capacity(8 * 17);
    for row in 0..8 {
        for col in 0..8 {
            let cell = if bit_at(b, row, col) { "x " } else { ". " };
            rendered.push_str(cell);
        }
        rendered.push('\n');
    }
    rendered
}
/// Returns whether the square at display row `i` and column `j` is set in `b`.
///
/// Row 0 is the most-significant byte and column 0 is the most-significant bit of that
/// byte, so `(0, 0)` is bit 63 and `(7, 7)` is bit 0. Intended for `i` and `j` in `0..8`.
pub fn bit_at(b: Bitboard, i: i32, j: i32) -> bool {
    let bit_index = (7 - i) * 8 + 7 - j;
    (b >> bit_index) & 1 == 1
}
/// Returns a bitboard containing only square `s` (bit `s` set).
pub fn from_square(s: Square) -> Bitboard {
    let single: Bitboard = 1;
    single << s
}

/// Moves every square one row toward ROW_8. Squares already on ROW_8 are shifted out.
pub fn north_one(b: Bitboard) -> Bitboard {
    // One row is one byte.
    let row_width = 8;
    b << row_width
}

/// Moves every square one row toward ROW_1. Squares already on ROW_1 are shifted out.
pub fn south_one(b: Bitboard) -> Bitboard {
    // One row is one byte.
    let row_width = 8;
    b >> row_width
}

pub fn west_one(b: Bitboard) -> Bitboard {
    (b << 1) & !FILE_H
}

pub fn east_one(b: Bitboard) -> Bitboard {
    (b >> 1) & !FILE_A
}

pub fn northeast_one(b: Bitboard) -> Bitboard {
    (b << 7) & !FILE_A
}

pub fn northwest_one(b: Bitboard) -> Bitboard {
    (b << 9) & !FILE_H
}

pub fn southwest_one(b: Bitboard) -> Bitboard {
    (b >> 7) & !FILE_H
}

pub fn southeast_one(b: Bitboard) -> Bitboard {
    (b >> 9) & !FILE_A
}

/// Returns the union of all rays a slider on `piece` can travel along.
///
/// * `straight` enables the north/south/east/west rays; `diagonal` enables the four
///   diagonal rays.
/// * Each ray stops at, and includes, the first square in `occupied`.
/// * With `only_occ`, each ray also stops one square short of the board edge. The last
///   square before falling off is excluded (see `generate_direction`).
pub fn generate_sliding_destinations(
    occupied: Bitboard,
    piece: Bitboard,
    straight: bool,
    diagonal: bool,
    only_occ: bool,
) -> Bitboard {
    let orthogonal_rays: [fn(Bitboard) -> Bitboard; 4] =
        [north_one, south_one, east_one, west_one];
    let diagonal_rays: [fn(Bitboard) -> Bitboard; 4] =
        [northeast_one, southeast_one, northwest_one, southwest_one];

    orthogonal_rays
        .iter()
        .filter(|_| straight)
        .chain(diagonal_rays.iter().filter(|_| diagonal))
        .fold(0, |reach, &ray| {
            reach | generate_direction(piece, ray, occupied, only_occ)
        })
}

/// Walks from `piece` repeatedly in `direction` and returns every square visited.
///
/// The starting square itself is not included. The walk stops after the first square
/// that intersects `occupied` (that square is included) or when the step yields 0. With
/// `only_occ`, the walk also stops before the square whose next step would yield 0, so
/// the last square before the edge is excluded.
fn generate_direction(piece: Bitboard, direction: fn(Bitboard) -> Bitboard, occupied: Bitboard, only_occ: bool) -> Bitboard {
    let mut reached: Bitboard = 0;
    let mut cursor = direction(piece);
    loop {
        let next = direction(cursor);
        if only_occ && next == 0 {
            return reached;
        }
        reached |= cursor;
        if cursor == 0 || cursor & occupied != 0 {
            return reached;
        }
        cursor = next;
    }
}


/// Maps the low bits of `index` onto the set bits of `mask`, lowest set bit first.
///
/// Bit `k` of `index` decides whether the `k`-th lowest set bit of `mask` appears in the
/// result. Iterating `index` over `0..(1 << mask.count_ones())` therefore enumerates
/// every subset of `mask` exactly once.
fn extract(mask: Bitboard, index: u32) -> Bitboard {
    let mut remaining = mask;
    let mut subset: Bitboard = 0;
    let mut position: u32 = 0;
    while remaining != 0 {
        // Isolate the lowest set bit of what's left of the mask.
        let lowest = remaining & remaining.wrapping_neg();
        if (index >> position) & 1 == 1 {
            subset |= lowest;
        }
        remaining ^= lowest;
        position += 1;
    }
    subset
}

/// Returns a sparse random `u64`: the bitwise AND of three consecutive draws.
///
/// With uniformly random draws, each bit ends up set with probability 1/8. Sparse
/// candidates are what the magic search in `create_magic_slider` feeds on.
pub fn low_one_random() -> u64 {
    let mut source = rand::thread_rng();
    let (first, second, third) = (source.next_u64(), source.next_u64(), source.next_u64());
    first & second & third
}

/// Searches for a magic multiplier for a rook (`diagonal == false`) or bishop
/// (`diagonal == true`) standing on square `s`.
///
/// On success returns `Some((mask, magic, table))`:
/// * `mask` is the relevant-occupancy mask: every ray with the board-edge squares removed.
/// * `magic` makes `(occupancy * magic) >> (64 - mask.count_ones())` map every subset of
///   `mask` to a slot holding that occupancy's attack set, with no conflicting collisions.
/// * `table` has exactly 4096 entries, indexed that way. Slots that are never produced
///   stay 0.
///
/// Returns `None` (after printing the mask) if no magic is found within the attempt budget.
fn create_magic_slider(s: Square, diagonal: bool) -> Option<(u64, u64, Vec<Bitboard>)> {
    // `create` pads every table to 1 << (largest mask bit count) and reads that many
    // entries, so the returned table must always have the full 12-bit length.
    const TABLE_LEN: usize = 4096;
    const MAX_ATTEMPTS: usize = 10_000_000;

    let bitboard = from_square(s);
    let patt = generate_sliding_destinations(0, bitboard, !diagonal, diagonal, true);
    let required_bits = patt.count_ones();
    let size = 1_usize << required_bits;
    debug_assert!(size <= TABLE_LEN);

    // Every subset of the mask and its true attack set are loop-invariant. Compute them
    // once here rather than on every one of the (up to millions of) magic candidates.
    let occupancies: Vec<Bitboard> = (0..size).map(|i| extract(patt, i as u32)).collect();
    let attacks: Vec<Bitboard> = occupancies
        .iter()
        .map(|&occ| generate_sliding_destinations(occ, bitboard, !diagonal, diagonal, false))
        .collect();

    let shift = 64 - required_bits;
    let mut table: Vec<Bitboard> = vec![0; TABLE_LEN];

    for _ in 0..MAX_ATTEMPTS {
        let magic = low_one_random();

        // Cheap pre-filter: skip candidates whose product with the full mask sets fewer
        // than 6 of the top 16 bits, the bits the index is taken from.
        if magic.wrapping_mul(patt).wrapping_shr(48).count_ones() < 6 {
            continue;
        }

        // Indices are always < `size`, so only that prefix can hold stale entries.
        table[..size].fill(0);

        // 0 marks an empty slot. A real attack set is never empty, because every square
        // has at least one on-board neighbour along the enabled rays.
        let collision = occupancies.iter().zip(attacks.iter()).any(|(&occ, &attack)| {
            let index = magic.wrapping_mul(occ).wrapping_shr(shift) as usize;
            let slot = &mut table[index];
            if *slot == 0 {
                *slot = attack;
                false
            } else {
                // Sharing a slot is fine only when both occupancies give the same attacks.
                *slot != attack
            }
        });

        if !collision {
            return Some((patt, magic, table));
        }
    }

    eprintln!("OH NOO\n{}", print_board(patt));
    None
}


fn create_tables(file: &mut File) {
    writeln!(file, "use crate::bitboard::*;").unwrap();
    create(file, false, "ROOK");
    create(file, true, "BISHOP");
}

/// Generates magic-bitboard data for one slider type and appends it to `file` as Rust source.
///
/// `diagonal` selects bishops (`true`) or rooks (`false`). `prefix` names the three
/// emitted constants:
/// * `{prefix}_MASKS_AND_MAGICS: [(Bitboard, Bitboard); 64]`: per-square mask and magic.
/// * `{prefix}_BITS: [u8; 64]`: per-square mask popcount (the index width).
/// * `{prefix}_TABLES: [[Bitboard; N]; 64]`: per-square attack tables, each padded to
///   the largest table size `N`.
///
/// # Panics
/// Panics if no magic is found for some square, or if writing to `file` fails.
fn create(file: &mut File, diagonal: bool, prefix: &str) {
    // One (mask, magic, table) entry per square, for both rooks and bishops.
    let entries: Vec<(u64, u64, Vec<u64>)> = (0..64_u8)
        .map(|square| create_magic_slider(square, diagonal).unwrap_or_else(|| panic!("NO")))
        .collect();

    // The tables alone are tens of thousands of numbers. Without buffering, every
    // `write!` below would be its own system call.
    let mut out = std::io::BufWriter::new(file);

    write!(out, "pub const {}_MASKS_AND_MAGICS: [(Bitboard, Bitboard); 64] = [", prefix).unwrap();
    for (mask, magic, _) in &entries {
        write!(out, "({}, {}), ", mask, magic).unwrap();
    }
    writeln!(out, "];").unwrap();

    write!(out, "pub const {}_BITS: [u8; 64] = [", prefix).unwrap();
    for (mask, _, _) in &entries {
        write!(out, "{}, ", mask.count_ones()).unwrap();
    }
    writeln!(out, "];").unwrap();

    // All squares share one array type, so every table is emitted at the largest size.
    let max_tablesize = 1_usize
        << entries
            .iter()
            .map(|(mask, _, _)| mask.count_ones())
            .max()
            .unwrap_or(12);

    write!(out, "pub const {}_TABLES: [[Bitboard; {}]; 64] = [", prefix, max_tablesize).unwrap();
    for (_, _, table) in &entries {
        write!(out, "[").unwrap();
        for result in &table[..max_tablesize] {
            write!(out, "{}, ", result).unwrap();
        }
        writeln!(out, "], ").unwrap();
    }
    writeln!(out, "];").unwrap();

    // Dropping a BufWriter swallows flush errors, so flush explicitly.
    out.flush().unwrap();
}