use board::Board; use evaluate::PosValue; use static_assertions::const_assert_eq; use std::collections::{HashMap}; use log::info; use zobrist; use crate::movegen::{SimpleMove}; use crate::zobrist::Hash; #[derive(Clone, PartialEq)] pub enum EntryType { Value, LowerBound, UpperBound, NoBound } #[derive(Clone)] pub struct CacheEntry { pub entry_type: EntryType, // 1 byte pub halfmove_age: u8, // 1 byte pub mov: SimpleMove, // 2 bytes pub value_and_depth: i32, // 4 bytes } const_assert_eq!(std::mem::size_of::(), 8); impl CacheEntry { pub fn null() -> Self { CacheEntry { entry_type: EntryType::Value, halfmove_age: 0, mov: SimpleMove { from: 0, to: 0 }, value_and_depth: 0 } } pub fn new_value(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self { CacheEntry { entry_type: EntryType::Value, halfmove_age, mov, value_and_depth: (value << 8) | (depth as i32), } } pub fn new_upper(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self { CacheEntry { entry_type: EntryType::UpperBound, halfmove_age, mov, value_and_depth: (value << 8) | (depth as i32), } } pub fn new_lower(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self { CacheEntry { entry_type: EntryType::LowerBound, halfmove_age, mov, value_and_depth: (value << 8) | (depth as i32), } } pub fn value(&self) -> PosValue { self.value_and_depth >> 8 } pub fn depth(&self) -> u8 { (self.value_and_depth & 0xFF) as u8 } } struct Bucket { hashes: [Hash; 4], entries: [CacheEntry; 4], } const_assert_eq!(std::mem::size_of::(), 64); pub struct Cache { table: Vec, } #[derive(Clone)] pub struct RepetitionTable { hashmap: HashMap, } impl Cache { pub fn new_in_megabytes(mbytes: usize) -> Self { Self::new(mbytes * 1024 * 1024 / std::mem::size_of::()) } pub fn new(length: usize) -> Self { let c = Cache { table: unsafe { let layout = std::alloc::Layout::array::(length).unwrap(); let allayout = layout.align_to(64).unwrap(); let ptr = std::alloc::alloc(allayout); Vec::from_raw_parts(ptr.cast(), length, length) }, }; c } fn get_index(&self, hash: zobrist::Hash) -> usize { (hash % (self.table.len() as u64)) as usize } pub fn lookup<'a>(&'a self, board: &Board) -> Option<&'a CacheEntry> { if board.zobrist.is_none() { info!("invalid zobrist"); } let hash = board.zobrist.as_ref().unwrap().1; let index = self.get_index(hash); let bucket = &self.table[index]; for i in 0..bucket.hashes.len() { if bucket.hashes[i] == hash { return Some(&bucket.entries[i]); } } return None; } fn should_replace(old_hash: Hash, old_ce: &CacheEntry, new_hash: Hash, new: &CacheEntry) -> bool { if old_hash == 0 { return true; } // different positions if old_hash != new_hash { if old_ce.halfmove_age < new.halfmove_age { return true; } if old_ce.entry_type == EntryType::Value { return false; } if old_ce.depth() <= new.depth() { return true; } } if new.depth() >= old_ce.depth() { return true; } return false; } pub fn cache(&mut self, game_pos: &Board, ce: CacheEntry) { if game_pos.zobrist.is_none() { info!("invalid zobrist"); } let hash = game_pos.zobrist.as_ref().unwrap().1; let index = self.get_index(hash); let bucket = &mut self.table[index]; for i in 0..bucket.hashes.len() { if bucket.hashes[i] == hash { if bucket.entries[i].depth() <= ce.depth() { bucket.entries[i] = ce; } return; } } for i in 0..bucket.hashes.len() { if Self::should_replace(bucket.hashes[i], &bucket.entries[i], hash, &ce) { bucket.hashes[i] = hash; bucket.entries[i] = ce; break; } } } pub fn fullness_permill(&self) -> usize { let mut permill = 0; for i in 0..25 { permill += 10 * self.table[i].hashes.iter().filter(|h| **h != 0).count(); } permill } pub fn clear(&mut self) { for bucket in &mut self.table { bucket.hashes = [0; 4]; } } } impl RepetitionTable { pub fn new() -> Self { RepetitionTable { hashmap: HashMap::with_capacity(64) } } pub fn clear(&mut self) { self.hashmap.clear(); } pub fn lookup(&self, hash: zobrist::Hash) -> i32 { *self.hashmap.get(&hash).unwrap_or(&0) } pub fn increment(&mut self, hash: zobrist::Hash) -> i32 { if let Some(entry) = self.hashmap.get_mut(&hash) { *entry += 1; *entry } else { self.hashmap.insert(hash, 1); 1 } } pub fn decrement(&mut self, hash: zobrist::Hash) { if let Some(entry) = self.hashmap.get_mut(&hash) { *entry -= 1; if *entry == 0 { self.hashmap.remove(&hash); } } } }