search.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. use std::clone;
  2. use movegen::*;
  3. use game::Game;
  4. use evaluate::*;
  5. use log::info;
  6. use rand::prelude::*;
  7. use ttable::*;
  8. ///
  9. /// struct to contain data for a search
  10. ///
  11. pub struct SearchControl<'a> {
  12. /// node counter
  13. pub nodes: usize,
  14. pub killer_moves: Vec<[Move; 2]>,
  15. pub last_move: Option<Move>,
  16. pub countermoves: [[Move; 64]; 64],
  17. /// current halfmove clock for discarding old hash entries
  18. pub halfmove_age: u16,
  19. /// function to check if the search should be exited
  20. pub check: &'a mut dyn FnMut() -> bool,
  21. pub stopping: bool,
  22. pub move_history: &'a mut RepetitionTable,
  23. /// depth the search was started at
  24. initial_depth: i32,
  25. nullmoves_enabled: bool,
  26. }
  27. pub enum SearchResult {
  28. Finished(Move, PosValue),
  29. Cancelled(Option<(Move, PosValue)>),
  30. Invalid
  31. }
  32. impl<'a> SearchControl<'a> {
  33. pub fn new(check: &'a mut dyn FnMut() -> bool, move_history: &'a mut RepetitionTable, depth: i32) -> Self {
  34. SearchControl {
  35. nodes: 0,
  36. killer_moves: vec![[Move::nullmove(); 2]; depth as usize],
  37. last_move: None,
  38. countermoves: [[Move::nullmove(); 64]; 64],
  39. halfmove_age: 0,
  40. check,
  41. stopping: false,
  42. move_history,
  43. initial_depth: depth,
  44. nullmoves_enabled: true
  45. }
  46. }
  47. pub fn set_depth(&mut self, depth: i32) {
  48. self.initial_depth = depth;
  49. self.killer_moves = vec![[Move::nullmove(); 2]; depth as usize];
  50. }
  51. pub fn insert_killer(&mut self, ply_depth: usize, killer: Move) {
  52. if self.is_killer(ply_depth, &killer) {
  53. return;
  54. }
  55. let nkm = self.killer_moves[ply_depth].len();
  56. for i in 1..nkm {
  57. self.killer_moves[ply_depth][i - 1] = self.killer_moves[ply_depth][i];
  58. }
  59. self.killer_moves[ply_depth][nkm - 1] = killer;
  60. }
  61. pub fn is_killer(&self, ply_depth: usize, killer: &Move) -> bool {
  62. self.killer_moves[ply_depth].contains(killer)
  63. }
  64. pub fn countermove_to(&self, last_move: Move) -> Move {
  65. let from = last_move.to_simple().from;
  66. let to = last_move.to_simple().to;
  67. self.countermoves[from as usize][to as usize]
  68. }
  69. }
  70. /**
  71. * searches for moves and returns the best move found plus its value
  72. */
  73. pub fn search(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, depth: i32) -> SearchResult {
  74. if depth == 0 {
  75. return SearchResult::Invalid;
  76. }
  77. let cache_entry = hash.lookup(game);
  78. let ply_depth = (sc.initial_depth - depth) as usize;
  79. let moves = generate_legal_sorted_moves(game, hash, &sc.killer_moves[ply_depth], cache_entry.map(CacheEntry::clone), false, game.turn);
  80. //let mut moves = generate_legal_moves(game, game.turn);
  81. //sort_moves(game, hash, &sc.killer_moves[ply_depth], &mut moves);
  82. info!("moves: {:?}", moves.iter().map(|mv| mv.to_string()).collect::<Vec<String>>());
  83. // use a slight offset for the alpha value in the root node in order to
  84. // determine possibly multiple good moves
  85. const ALPHA_OFFSET: PosValue = 0 as PosValue;
  86. let mut valued_moves: Vec<(Move, PosValue)> = Vec::with_capacity(moves.len());
  87. let mut cancelled = false;
  88. for mov in moves {
  89. let undo = game.apply(mov);
  90. let val = -negamax(game, sc, hash, decrease_mate_in(-beta), decrease_mate_in(-alpha), depth - 1);
  91. //info!("moveval {} -> {}\n", mov.to_string(), val);
  92. game.undo_move(undo);
  93. if sc.stopping {
  94. cancelled = true;
  95. break;
  96. }
  97. if increase_mate_in(val) > alpha {
  98. alpha = increase_mate_in(val) - ALPHA_OFFSET;
  99. valued_moves.push((mov, -alpha));
  100. }
  101. }
  102. valued_moves.sort_by_key(|mv| mv.1);
  103. if valued_moves.len() > 0 {
  104. let min_val = valued_moves[0].1;
  105. let best_moves = valued_moves.iter().filter(|mv| mv.1 == min_val).collect::<Vec<&(Move, PosValue)>>();
  106. let mut rng = rand::thread_rng();
  107. let chosen_mov = best_moves[(rng.next_u64() % best_moves.len() as u64) as usize];
  108. if cancelled {
  109. return SearchResult::Cancelled(Some((chosen_mov.0, -chosen_mov.1)));
  110. }
  111. else {
  112. hash.cache(game, CacheEntry::new_value(depth as _, sc.halfmove_age as _, chosen_mov.0.to_simple(), chosen_mov.1 as _));
  113. return SearchResult::Finished(chosen_mov.0, -chosen_mov.1);
  114. }
  115. }
  116. else {
  117. return SearchResult::Invalid;
  118. }
  119. }
  120. pub fn negamax(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, mut depth: i32) -> PosValue {
  121. let last_move = sc.last_move;
  122. // we can't beat an alpha this good
  123. if alpha >= mate_in_p1(1) {
  124. return alpha;
  125. }
  126. let cache_entry = hash.lookup(game);
  127. if let Some(e) = &cache_entry {
  128. if e.depth() as i32 >= depth {
  129. //println!("TABLE HIT!");
  130. match e.entry_type {
  131. EntryType::Value => { return e.value(); },
  132. EntryType::LowerBound => {
  133. if e.value() >= beta { return beta; }
  134. },
  135. EntryType::UpperBound => {
  136. if e.value() < alpha { return alpha; }
  137. },
  138. }
  139. }
  140. }
  141. if depth == 0 {
  142. return quiescence_search(game, sc, hash, alpha, beta, 9);
  143. }
  144. sc.nodes += 1;
  145. if sc.nodes % 1024 == 0 {
  146. if (sc.check)() {
  147. sc.stopping = true;
  148. return 0;
  149. }
  150. }
  151. let ply_depth = (sc.initial_depth - depth) as usize;
  152. /*let moves = generate_legal_sorted_moves(
  153. game,
  154. hash,
  155. &sc.killer_moves[ply_depth],
  156. cache_entry,
  157. game.turn);*/
  158. let mut moves = MoveGenerator::generate_legal_moves(
  159. game,
  160. cache_entry,
  161. &sc.killer_moves[ply_depth],
  162. last_move.map(|lm| sc.countermove_to(lm)),
  163. game.turn
  164. );
  165. //info!("nega moves: {:?}", moves.iter().map(|mv| mv.to_string()).collect::<Vec<String>>());
  166. let check = is_check(game, game.turn);
  167. if moves.is_empty() {
  168. if check {
  169. // mate
  170. return checkmated();
  171. }
  172. else {
  173. // stalemate
  174. return 0;
  175. }
  176. }
  177. // Nullmove
  178. if sc.nullmoves_enabled && !check && depth >= 4 && game.get_all_side(game.turn).count_ones() > 5 {
  179. let reduce = if depth > 5 { if depth > 7 { 5 } else { 4 }} else { 3 };
  180. let nmov = Move::nullmove();
  181. let undo = game.apply(nmov);
  182. sc.nullmoves_enabled = false;
  183. let val = -negamax(game, sc, hash, -beta, -alpha, depth - reduce);
  184. sc.nullmoves_enabled = true;
  185. game.undo_move(undo);
  186. if is_mate_in_p1(val).is_none() {
  187. if val >= beta {
  188. depth -= reduce - 1;
  189. //return beta;
  190. }
  191. }
  192. }
  193. let mut best_move: Option<Move> = None;
  194. while let Some(mov) = moves.next() {
  195. //println!("mov: {}", mov.to_string());
  196. let undo = game.apply(mov);
  197. sc.last_move = Some(mov);
  198. let val = increase_mate_in(
  199. -negamax(game, sc, hash, -decrease_mate_in(beta), -decrease_mate_in(alpha), depth - 1)
  200. );
  201. game.undo_move(undo);
  202. // return if the search has been cancelled
  203. if sc.stopping {
  204. return alpha;
  205. }
  206. if val >= beta {
  207. hash.cache(game, CacheEntry::new_lower(depth as _, sc.halfmove_age as _, mov.to_simple(), val));
  208. if !mov.is_capture() {
  209. sc.insert_killer(ply_depth, mov);
  210. if depth >= 2 {
  211. sc.countermoves[mov.to_simple().from as usize][mov.to_simple().to as usize] = mov;
  212. }
  213. }
  214. return val;
  215. }
  216. if val > alpha {
  217. alpha = val;
  218. best_move = Some(mov);
  219. if depth >= 3 {
  220. sc.countermoves[mov.to_simple().from as usize][mov.to_simple().to as usize] = mov;
  221. }
  222. if alpha >= mate_in_p1(1) {
  223. break;
  224. }
  225. }
  226. if let Some(lm) = last_move {
  227. moves.update_counter(sc.countermove_to(lm));
  228. }
  229. }
  230. if let Some(mov) = best_move {
  231. // alpha is exact
  232. hash.cache(game, CacheEntry::new_value(depth as _, sc.halfmove_age as _, mov.to_simple(), alpha));
  233. let cur_depth = (sc.initial_depth - depth) as usize;
  234. }
  235. else {
  236. hash.cache(game, CacheEntry::new_upper(depth as _, sc.halfmove_age as _, SimpleMove { from: 0, to: 0 }, alpha));
  237. }
  238. //info!("best alpha {}", alpha);
  239. return alpha;
  240. }
  241. fn quiescence_search(game: &mut Game, sc: &mut SearchControl, hash: &mut Cache, mut alpha: PosValue, beta: PosValue, depth: i32) -> PosValue {
  242. sc.nodes += 1;
  243. if sc.nodes % 1024 == 0 {
  244. if (sc.check)() {
  245. sc.stopping = true;
  246. return 0;
  247. }
  248. }
  249. let val = evaluate(game);
  250. if val >= beta {
  251. return beta;
  252. }
  253. if val > alpha {
  254. alpha = val;
  255. }
  256. if depth <= 0 {
  257. return alpha;
  258. }
  259. if game.get_piece(KING, game.turn) == 0 { return checkmated(); }
  260. //let mut moves = generate_legal_sorted_moves(game, hash, &[], None, true, game.turn);
  261. let mut moves = generate_legal_moves(game, game.turn, true);
  262. //sort_moves_no_hash(game, &mut moves);
  263. sort_moves_least_valuable_attacker(game, &mut moves);
  264. for mov in moves {
  265. let undo = game.apply(mov);
  266. let val = -quiescence_search(game, sc, hash, decrease_mate_in(-beta), decrease_mate_in(-alpha), depth - 1);
  267. game.undo_move(undo);
  268. if sc.stopping {
  269. return alpha
  270. }
  271. if val >= beta {
  272. return beta;
  273. }
  274. if increase_mate_in(val) > alpha {
  275. alpha = increase_mate_in(val);
  276. }
  277. }
  278. return alpha;
  279. }
  280. pub fn perft(game: &mut Game, sc: &mut SearchControl, depth: i32) -> bool {
  281. let moves = generate_legal_moves(game, game.turn, false);
  282. if depth <= 1 {
  283. sc.nodes += moves.len();
  284. if sc.nodes % 1024 < moves.len() {
  285. if (sc.check)() {
  286. return true;
  287. }
  288. }
  289. return false;
  290. }
  291. for mov in moves {
  292. let nodes_before = sc.nodes;
  293. let undo = game.apply(mov);
  294. let do_return = perft(game, sc, depth - 1);
  295. game.undo_move(undo);
  296. if depth >= sc.initial_depth {
  297. println!("{}: {}", mov.to_string(), sc.nodes - nodes_before);
  298. }
  299. if do_return {
  300. return true;
  301. }
  302. }
  303. return false;
  304. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Runs `perft` on known positions and compares the leaf counts against
      /// reference values for depths 1, 2, 3, ...
      #[test]
      fn test_move_generation() {
          let cases: [(&str, Vec<usize>); 2] = [
              (
                  "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
                  vec![44, 1486, 62379, 2103487],
              ),
              (
                  "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",
                  vec![46, 2079, 89890, 3894594],
              ),
          ];
          for &(fen, ref expected_counts) in cases.iter() {
              let mut game = Game::from_fen_str(fen).unwrap();
              for (depth, &expected) in (1..).zip(expected_counts.iter()) {
                  let mut never_stop = || false;
                  let mut repetitions = RepetitionTable::new();
                  let mut sc = SearchControl::new(&mut never_stop, &mut repetitions, depth);
                  perft(&mut game, &mut sc, depth);
                  assert_eq!(sc.nodes, expected);
              }
          }
      }
  }