search.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. use movegen::*;
  2. use game::Game;
  3. use evaluate::*;
  4. use log::info;
  5. use rand::prelude::*;
  6. use hash::*;
  7. pub struct SearchControl<'a> {
  8. /// node counter
  9. pub nodes: usize,
  10. pub pv: Vec<Move>,
  11. pub killer_moves: Vec<[Move; 2]>,
  12. /// function to check if the search should be exited
  13. pub check: &'a mut dyn FnMut() -> bool,
  14. pub stopping: bool,
  15. pub move_history: &'a mut RepetitionTable,
  16. /// depth the search was started at
  17. pub initial_depth: i32,
  18. }
  19. pub enum SearchResult {
  20. Finished(Move, PosValue),
  21. Cancelled(Option<(Move, PosValue)>),
  22. Invalid
  23. }
  24. impl<'a> SearchControl<'a> {
  25. pub fn new(check: &'a mut dyn FnMut() -> bool, move_history: &'a mut RepetitionTable, depth: i32) -> Self {
  26. SearchControl {
  27. nodes: 0,
  28. pv: Vec::with_capacity(depth as usize),
  29. killer_moves: vec![[Move::Nullmove; 2]; depth as usize],
  30. check,
  31. stopping: false,
  32. move_history,
  33. initial_depth: depth
  34. }
  35. }
  36. pub fn set_depth(&mut self, depth: i32) {
  37. self.initial_depth = depth;
  38. self.killer_moves = vec![[Move::Nullmove; 2]; depth as usize];
  39. self.pv = vec![Move::Nullmove; depth as _];
  40. }
  41. pub fn insert_killer(&mut self, ply_depth: usize, killer: Move) {
  42. if self.is_killer(ply_depth, &killer) {
  43. return;
  44. }
  45. let nkm = self.killer_moves[ply_depth].len();
  46. for i in 1..nkm {
  47. self.killer_moves[ply_depth][i - 1] = self.killer_moves[ply_depth][i];
  48. }
  49. self.killer_moves[ply_depth][nkm - 1] = killer;
  50. }
  51. pub fn is_killer(&self, ply_depth: usize, killer: &Move) -> bool {
  52. self.killer_moves[ply_depth].contains(killer)
  53. }
  54. }
  55. /**
  56. * searches for moves and returns the best move found plus its value
  57. */
  58. pub fn search(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, depth: i32) -> SearchResult {
  59. if depth == 0 {
  60. return SearchResult::Invalid;
  61. }
  62. let mut moves = generate_legal_moves(game, game.turn);
  63. let mut rng = rand::thread_rng();
  64. let ply_depth = (sc.initial_depth - depth) as usize;
  65. sort_moves(game, hash, &sc.killer_moves[ply_depth], &mut moves);
  66. info!("moves: {:?}", moves.iter().map(|mv| mv.to_string()).collect::<Vec<String>>());
  67. //let mut alpha: PosValue = MIN_VALUE;
  68. //let mut beta: PosValue = MAX_VALUE;
  69. // use a slight offset for the alpha value in the root node in order to
  70. // determine possibly multiple good moves
  71. const ALPHA_OFFSET: PosValue = 0 as PosValue;
  72. let mut valued_moves: Vec<(Move, PosValue)> = Vec::with_capacity(moves.len());
  73. let mut cancelled = false;
  74. for mov in moves {
  75. let undo = game.apply(mov);
  76. //assert_eq!(new_game, *game, );
  77. //info!("searching {}", mov.to_string());
  78. let val = -negamax(game, sc, hash, decrease_mate_in(-beta), decrease_mate_in(-alpha), depth - 1);
  79. game.undo_move(undo);
  80. if sc.stopping {
  81. //return (Move::default(), 0);
  82. cancelled = true;
  83. break;
  84. }
  85. if val >= mate_in_p1(1) {
  86. // mate in 1 --- can't get better than that
  87. info!("yay mate!");
  88. //return SearchResult::Finished(mov, increase_mate_in(val));
  89. }
  90. //info!("searched {} -> {}", mov.to_string(), val);
  91. if increase_mate_in(val) > alpha {
  92. alpha = increase_mate_in(val) - ALPHA_OFFSET;
  93. valued_moves.push((mov, -alpha));
  94. }
  95. }
  96. //info!("movvalues: {:?}", valued_moves.iter().map(|mv| mv.0.to_string() + " - " + &mv.1.to_string()).collect::<Vec<String>>());
  97. valued_moves.sort_by_key(|mv| mv.1);
  98. //info!("best movvalues: {:?}", valued_moves.iter().map(|mv| mv.0.to_string() + " - " + &mv.1.to_string()).collect::<Vec<String>>());
  99. if valued_moves.len() > 0 {
  100. let min_val = valued_moves[0].1;
  101. let best_moves = valued_moves.iter().filter(|mv| mv.1 == min_val).collect::<Vec<&(Move, PosValue)>>();
  102. //info!("bestmove value {}", -min_val);
  103. let chosen_mov = best_moves[(rng.next_u64() % best_moves.len() as u64) as usize];
  104. if cancelled {
  105. return SearchResult::Cancelled(Some((chosen_mov.0, -chosen_mov.1)));
  106. }
  107. else {
  108. hash.cache(game, CacheEntry::new_value(depth, chosen_mov.1));
  109. sc.pv[0] = chosen_mov.0;
  110. return SearchResult::Finished(chosen_mov.0, -chosen_mov.1);
  111. }
  112. }
  113. else {
  114. return SearchResult::Invalid;
  115. }
  116. }
  117. pub fn negamax(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, depth: i32) -> PosValue {
  118. if let Some(e) = hash.lookup(game) {
  119. if e.depth >= depth {
  120. //println!("TABLE HIT!");
  121. match e.entry_type {
  122. EntryType::Value => { return e.value; },
  123. //EntryType::Value => { if e.value >= alpha { return (e.value, false); } else { return (alpha, false); } },
  124. //EntryType::LowerBound => { if e.value > alpha { return (e.value, false) } },
  125. EntryType::LowerBound => {
  126. if e.value < alpha { return alpha; }
  127. //if e.value >= beta { return (beta, false); }
  128. },
  129. //EntryType::UpperBound => { if e.value >= beta { return (beta, false); } },
  130. EntryType::UpperBound => {
  131. if e.value >= beta { return beta; }
  132. },
  133. }
  134. }
  135. }
  136. if depth == 0 {
  137. return quiescence_search(game, sc, hash, alpha, beta, 9);
  138. }
  139. sc.nodes += 1;
  140. if sc.nodes % 1024 == 0 {
  141. if (sc.check)() {
  142. sc.stopping = true;
  143. return 0;
  144. }
  145. }
  146. let ply_depth = (sc.initial_depth - depth) as usize;
  147. let moves = generate_legal_sorted_moves(game, hash, &sc.killer_moves[ply_depth], game.turn);
  148. //info!("nega moves: {:?}", moves.iter().map(|mv| mv.to_string()).collect::<Vec<String>>());
  149. let check = is_check(game, game.turn);
  150. if moves.len() == 0 {
  151. if check {
  152. // mate
  153. return -mate();
  154. }
  155. else {
  156. // stalemate
  157. return 0;
  158. }
  159. }
  160. // Nullmove
  161. if false && !check && depth >= 4 && game_lateness(game) < 80 {
  162. let nmov = Move::Nullmove;
  163. let undo = game.apply(nmov);
  164. let val = -negamax(game, sc, hash, -beta, -alpha, depth / 2 - 1);
  165. game.undo_move(undo);
  166. if sc.stopping {
  167. return alpha;
  168. }
  169. if is_mate_in_p1(val).is_none() {
  170. if val >= beta {
  171. return beta;
  172. }
  173. }
  174. }
  175. //sort_moves(game, hash, &sc.killer_moves, &mut moves);
  176. //moves.sort_by_cached_key(|m| if sc.is_killer(*m) { -1 } else { 0 });
  177. let mut alpha_is_exact = false;
  178. let mut best_move = Move::Nullmove;
  179. // we can't beat an alpha this good
  180. if alpha >= mate() {
  181. return alpha;
  182. }
  183. for mov in moves {
  184. let undo = game.apply(mov);
  185. let val = -negamax(game, sc, hash, -decrease_mate_in(beta), -decrease_mate_in(alpha), depth - 1);
  186. game.undo_move(undo);
  187. //val = increase_mate_in(val);
  188. // return if the search has been cancelled
  189. if sc.stopping {
  190. return alpha;
  191. }
  192. if val >= beta {
  193. hash.cache(game, CacheEntry::new_upper(depth, val));
  194. //println!("but ret {}: {}", depth, beta);
  195. //info!("{} causes beta cutoff at {} ", mov.to_string(), val);
  196. if !mov.is_capture() {
  197. sc.insert_killer(ply_depth, mov);
  198. }
  199. return beta;
  200. }
  201. if increase_mate_in(val) > alpha {
  202. alpha = increase_mate_in(val);
  203. alpha_is_exact = true;
  204. best_move = mov;
  205. }
  206. }
  207. if alpha_is_exact {
  208. hash.cache(game, CacheEntry::new_value(depth, alpha));
  209. let cur_depth = (sc.initial_depth - depth) as usize;
  210. sc.pv[cur_depth] = best_move;
  211. }
  212. else {
  213. hash.cache(game, CacheEntry::new_lower(depth, alpha));
  214. }
  215. //info!("best alpha {}", alpha);
  216. return alpha;
  217. }
  218. fn quiescence_search(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, depth: i32) -> PosValue {
  219. sc.nodes += 1;
  220. if sc.nodes % 1024 == 0 {
  221. if (sc.check)() {
  222. sc.stopping = true;
  223. return 0;
  224. }
  225. }
  226. let val = evaluate(game);
  227. if val >= beta {
  228. return beta;
  229. }
  230. if val > alpha {
  231. alpha = val;
  232. }
  233. if depth <= 0 {
  234. return alpha;
  235. }
  236. if game.get_piece(KING, game.turn) == 0 { return -mate(); }
  237. let mut moves = generate_attacking_moves(game, game.turn);
  238. sort_moves_no_hash(game, &mut moves);
  239. for mov in moves {
  240. let undo = game.apply(mov);
  241. let val = -quiescence_search(game, sc, hash, decrease_mate_in(-beta), decrease_mate_in(-alpha), depth - 1);
  242. game.undo_move(undo);
  243. if sc.stopping {
  244. return alpha
  245. }
  246. if val >= beta {
  247. return beta;
  248. }
  249. if increase_mate_in(val) > alpha {
  250. alpha = increase_mate_in(val);
  251. }
  252. }
  253. return alpha;
  254. }
  255. pub fn perft(game: &mut Game, sc: &mut SearchControl, depth: i32) -> bool {
  256. let moves = generate_legal_moves(game, game.turn);
  257. if depth <= 1 {
  258. sc.nodes += moves.len();
  259. if sc.nodes % 1024 < moves.len() {
  260. if (sc.check)() {
  261. return true;
  262. }
  263. }
  264. return false;
  265. }
  266. for mov in moves {
  267. let undo = game.apply(mov);
  268. let do_return = perft(game, sc, depth - 1);
  269. game.undo_move(undo);
  270. if do_return {
  271. return true;
  272. }
  273. }
  274. return false;
  275. }
  276. #[cfg(test)]
  277. mod tests {
  278. use super::*;
  279. #[test]
  280. fn test_move_generation() {
  281. let positions = [
  282. "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
  283. "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10"
  284. ];
  285. let perft_results: [Vec<usize>; 2] = [
  286. vec![44, 1486, 62379, 2103487],
  287. vec![46, 2079, 89890, 3894594]
  288. ];
  289. for (i, &position) in positions.iter().enumerate() {
  290. let mut game = Game::from_fen_str(position).unwrap();
  291. for (j, &p_res) in perft_results[i].iter().enumerate() {
  292. let depth = j + 1;
  293. let mut check_fn = || false;
  294. let mut rt = RepetitionTable::new();
  295. let mut sc = SearchControl::new(&mut check_fn, &mut rt, depth as _);
  296. perft(&mut game, &mut sc, depth as _);
  297. assert_eq!(sc.nodes, p_res);
  298. }
  299. }
  300. }
  301. }