ttable.rs 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. use game::Game;
  2. use evaluate::PosValue;
  3. use static_assertions::const_assert_eq;
  4. use std::collections::{HashMap};
  5. use log::info;
  6. use zobrist;
  7. use crate::movegen::{Move, SimpleMove};
  8. use crate::zobrist::Hash;
  9. #[derive(Clone, PartialEq)]
  10. pub enum EntryType {
  11. Value,
  12. LowerBound,
  13. UpperBound,
  14. }
  15. #[derive(Clone)]
  16. pub struct CacheEntry {
  17. pub entry_type: EntryType, // 1 byte
  18. pub halfmove_age: u8, // 1 byte
  19. pub mov: SimpleMove, // 2 bytes
  20. pub value_and_depth: i32, // 4 bytes
  21. }
  22. const_assert_eq!(std::mem::size_of::<CacheEntry>(), 8);
  23. impl CacheEntry {
  24. pub fn null() -> Self {
  25. CacheEntry { entry_type: EntryType::Value, halfmove_age: 0, mov: SimpleMove { from: 0, to: 0 }, value_and_depth: 0 }
  26. }
  27. pub fn new_value(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self {
  28. CacheEntry {
  29. entry_type: EntryType::Value,
  30. halfmove_age,
  31. mov,
  32. value_and_depth: (value << 8) | (depth as i32),
  33. }
  34. }
  35. pub fn new_upper(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self {
  36. CacheEntry {
  37. entry_type: EntryType::UpperBound,
  38. halfmove_age,
  39. mov,
  40. value_and_depth: (value << 8) | (depth as i32),
  41. }
  42. }
  43. pub fn new_lower(depth: u8, halfmove_age: u8, mov: SimpleMove, value: PosValue) -> Self {
  44. CacheEntry {
  45. entry_type: EntryType::LowerBound,
  46. halfmove_age,
  47. mov,
  48. value_and_depth: (value << 8) | (depth as i32),
  49. }
  50. }
  51. pub fn value(&self) -> PosValue {
  52. self.value_and_depth >> 8
  53. }
  54. pub fn depth(&self) -> u8 {
  55. (self.value_and_depth & 0xFF) as u8
  56. }
  57. }
  58. struct Bucket {
  59. hashes: [Hash; 4],
  60. entries: [CacheEntry; 4],
  61. }
  62. const_assert_eq!(std::mem::size_of::<Bucket>(), 64);
  63. pub struct Cache {
  64. table: Vec<Bucket>,
  65. }
  66. #[derive(Clone)]
  67. pub struct RepetitionTable {
  68. hashmap: HashMap<zobrist::Hash, i32>,
  69. }
  70. impl Cache {
  71. pub fn megabytes(mbytes: usize) -> Self {
  72. Self::new(mbytes * 1024 * 1024 / std::mem::size_of::<Bucket>())
  73. }
  74. pub fn new(length: usize) -> Self {
  75. let c = Cache {
  76. table: unsafe {
  77. let layout = std::alloc::Layout::array::<Bucket>(length).unwrap();
  78. let allayout = layout.align_to(64).unwrap();
  79. let ptr = std::alloc::alloc(allayout);
  80. Vec::from_raw_parts(ptr.cast(), length, length)
  81. },
  82. };
  83. c
  84. }
  85. fn get_index(&self, hash: zobrist::Hash) -> usize {
  86. (hash % (self.table.len() as u64)) as usize
  87. }
  88. pub fn lookup<'a>(&'a self, board: &Game) -> Option<&'a CacheEntry> {
  89. if board.zobrist.is_none() {
  90. info!("invalid zobrist");
  91. }
  92. let hash = board.zobrist.as_ref().unwrap().1;
  93. let index = self.get_index(hash);
  94. let bucket = &self.table[index];
  95. for i in 0..bucket.hashes.len() {
  96. if bucket.hashes[i] == hash {
  97. return Some(&bucket.entries[i]);
  98. }
  99. }
  100. return None;
  101. }
  102. fn should_replace(old_hash: Hash, old_ce: &CacheEntry, new_hash: Hash, new: &CacheEntry) -> bool {
  103. if old_hash == 0 {
  104. return true;
  105. }
  106. // different positions
  107. if old_hash != new_hash {
  108. if old_ce.halfmove_age < new.halfmove_age {
  109. return true;
  110. }
  111. if old_ce.entry_type == EntryType::Value {
  112. return false;
  113. }
  114. if old_ce.depth() <= new.depth() {
  115. return true;
  116. }
  117. }
  118. if new.depth() >= old_ce.depth() {
  119. return true;
  120. }
  121. return false;
  122. }
  123. pub fn cache(&mut self, game_pos: &Game, ce: CacheEntry) {
  124. if game_pos.zobrist.is_none() {
  125. info!("invalid zobrist");
  126. }
  127. let hash = game_pos.zobrist.as_ref().unwrap().1;
  128. let index = self.get_index(hash);
  129. let bucket = &mut self.table[index];
  130. for i in 0..bucket.hashes.len() {
  131. if bucket.hashes[i] == hash {
  132. bucket.entries[i] = ce;
  133. return;
  134. }
  135. }
  136. for i in 0..bucket.hashes.len() {
  137. if Self::should_replace(bucket.hashes[i], &bucket.entries[i], hash, &ce) {
  138. bucket.hashes[i] = hash;
  139. bucket.entries[i] = ce;
  140. break;
  141. }
  142. }
  143. }
  144. pub fn fullness_permill(&self) -> usize {
  145. let mut permill = 0;
  146. for i in 0..25 {
  147. permill += 10 * self.table[i].hashes.iter().filter(|h| **h != 0).count();
  148. }
  149. permill
  150. }
  151. }
  152. impl RepetitionTable {
  153. pub fn new() -> Self {
  154. RepetitionTable {
  155. hashmap: HashMap::with_capacity(1024)
  156. }
  157. }
  158. pub fn clear(&mut self) {
  159. self.hashmap.clear();
  160. }
  161. pub fn increment(&mut self, hash: zobrist::Hash) -> i32 {
  162. if let Some(entry) = self.hashmap.get_mut(&hash) {
  163. *entry += 1;
  164. *entry
  165. }
  166. else {
  167. self.hashmap.insert(hash, 1);
  168. 1
  169. }
  170. }
  171. pub fn decrement(&mut self, hash: zobrist::Hash) {
  172. if let Some(entry) = self.hashmap.get_mut(&hash) {
  173. *entry -= 1;
  174. if *entry == 0 {
  175. self.hashmap.remove(&hash);
  176. }
  177. }
  178. }
  179. }