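数论中有一个结论:设 b 为奇数,a 为偶数且 0 ≤ a < 2b,令 c = a mod b。若 c 为奇数,则 a > b;若 c 为偶数,则 a < b。(理由:a < b 时 c = a,为偶数;a > b 时 c = a − b,偶数减奇数为奇数;a 与 b 奇偶性不同,故不可能 a = b。)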
推导过程
脚本:
LSB oracle 二分核心逻辑(示意片段,n、e、c 与 oracle 需事先定义):
# Core of the RSA LSB-oracle attack: bisect the interval holding the plaintext.
# Expects n (modulus), e (public exponent), c (ciphertext) and oracle(c) to be
# defined beforehand.  NOTE(review): oracle(c) is assumed to return the
# least-significant bit (0 = even, 1 = odd) of the decryption of c — confirm.
from fractions import Fraction

# Exact rational bounds (invariant L <= m < H).  With integer floor division
# each halving can lose up to half a unit and the error accumulates over
# bit_length(n) steps, so `m = L` could land far from the real plaintext.
L = Fraction(0)
H = Fraction(n)
t = pow(2, e, n)  # multiplying the ciphertext by 2**e doubles the plaintext
for _ in range(n.bit_length()):
    c = (t * c) % n
    mid = (L + H) / 2
    # n is odd, so 2x mod n is even iff 2x < n: plaintext is in the lower half.
    if oracle(c) == 0:
        H = mid
    else:
        L = mid
# After bit_length(n) halvings H - L = n / 2**k < 1, so the plaintext is the
# unique integer in [L, H), i.e. ceil(L), computed exactly without floats.
m = -(-L.numerator // L.denominator)  # plain text
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# LSB-oracle attack client: recovers an RSA plaintext by bisection.
#
# Connects to a local challenge service (port 2334 when any command-line
# argument is given, otherwise 2333), reads N, e and c, then submits
# c * 2**(k*e) mod N for k = 1, 2, ...; the parity reported for each query
# decides which half of the current interval holds the plaintext.
# Only single-argument print(...) calls are used so the script runs on both
# Python 2 and Python 3.
from pwn import *
import libnum
import Crypto
import re
import sys
from binascii import hexlify, unhexlify
from fractions import Fraction

if len(sys.argv) > 1:
    p = remote("127.0.0.1", 2334)
else:
    p = remote('127.0.0.1', 2333)
# context.log_level = 'debug'  # uncomment to dump all traffic


def _recv_text():
    """Read from the service up to and including the next "temp_c" marker.

    The data is returned as text so the str regex patterns used below work
    whether pwntools hands back bytes or str.
    """
    data = p.recvuntil(b"temp_c")
    if isinstance(data, bytes):
        data = data.decode("latin-1")
    return data


def _parse_int(name, text):
    """Return the first integer written as "<name> = <digits>" in *text*."""
    return int(re.findall(name + r"\s*=\s*(\d+)", text)[0])


def _ceil_fraction(q):
    """Return the smallest integer >= Fraction *q*, using exact integer math."""
    return -(-q.numerator // q.denominator)


def _int_to_bytes(m):
    """Big-endian byte encoding of the non-negative integer *m*."""
    digits = "%x" % m
    if len(digits) % 2:
        # unhexlify needs an even number of hex digits.
        digits = "0" + digits
    return unhexlify(digits)


def oracle(c):
    """Submit ciphertext *c* 20 times and summarise the "l = ..." replies.

    Returns [flag2, flag0]: flag2 is 1 if any collected value exceeds 10000,
    flag0 is 1 if any collected value is odd (each flag is 0 otherwise).
    NOTE(review): repeating the query 20 times and treating any odd sample
    as "odd" suggests the service's answers are randomised — confirm.
    """
    values = []
    for _ in range(20):
        p.sendline(str(c).encode())
        reply = _recv_text()
        values.append(int(re.findall(r"l\s*=\s*([0-9]*)", reply)[0]))
    flag0 = 1 if any(v % 2 != 0 for v in values) else 0
    flag2 = 1 if any(v > 10000 for v in values) else 0
    return [flag2, flag0]


def main():
    """Run the attack; print the recovered plaintext bytes and exit(0).

    Returns normally (so the caller falls through to interactive mode) only
    when no candidate near the bisection result re-encrypts to the original c.
    """
    banner = _recv_text()
    N = _parse_int("N", banner)
    e = _parse_int("e", banner)
    c = _parse_int("c", banner)
    size = libnum.len_in_bits(N)
    print("N= %d" % N)
    print("e= %d" % e)
    print("c= %d" % c)

    c_orig = c  # c is overwritten below; keep it to verify candidates
    two_e = pow(2, e, N)  # multiplying by 2**e mod N doubles the plaintext
    c = (two_e * c) % N
    # Exact rational bounds (invariant LB <= m < UB).  A fixed number of
    # halvings always terminates; afterwards UB - LB = N / 2**size < 1.
    LB = Fraction(0)
    UB = Fraction(N)
    for i in range(1, size + 1):
        flag = oracle(c)
        print("%d %s" % (i, flag))
        mid = (LB + UB) / 2
        if flag[1] == 0:  # no odd sample: treat as even -> lower half
            UB = mid
        else:
            LB = mid
        c = (two_e * c) % N
    print("%d" % int(LB))
    print("%d" % int(UB))

    guess = _ceil_fraction(LB)
    # guess is exact if every oracle answer was right; scanning nearby values
    # (closest first) tolerates mistakes in the last few answers.
    for delta in sorted(range(-128, 128), key=abs):
        cand = guess + delta
        if cand >= 0 and pow(cand, e, N) == c_orig:
            print(_int_to_bytes(cand))
            sys.exit(0)
    print("no plaintext found near %d" % guess)


if __name__ == '__main__':
    main()
    p.interactive()