smt: abstract multiplication by daejunpark · Pull Request #340 · a16z/halmos · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

smt: abstract multiplication #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
264: Function("f_evm_bvurem_264", BitVecSort264, BitVecSort264, BitVecSort264),
512: Function("f_evm_bvurem_512", BitVecSort512, BitVecSort512, BitVecSort512),
}
# Uninterpreted multiplication symbols, keyed by operand bit-width.
# mk_mul looks up f_mul[x.size()], so only 256- and 512-bit operands are
# supported; any other width raises KeyError at the lookup.
# arith() reaches mk_mul for EVM.MUL only when neither operand is a concrete
# bit-vector value, so symbolic*symbolic products stay abstract instead of
# being encoded as a native bit-vector multiplication.
f_mul = {
256: Function("f_evm_bvmul", BitVecSort256, BitVecSort256, BitVecSort256),
512: Function("f_evm_bvmul_512", BitVecSort512, BitVecSort512, BitVecSort512),
}
# Fixed-width (256-bit) uninterpreted symbols for signed division, signed
# remainder, and exponentiation. f_exp is returned by arith() for EXP when the
# exponent is not small enough to unroll into repeated multiplication;
# NOTE(review): f_sdiv/f_smod callers are outside this view — confirm usage.
f_sdiv = Function("f_evm_bvsdiv", BitVecSort256, BitVecSort256, BitVecSort256)
f_smod = Function("f_evm_bvsrem", BitVecSort256, BitVecSort256, BitVecSort256)
f_exp = Function("f_evm_exp", BitVecSort256, BitVecSort256, BitVecSort256)
Expand Down Expand Up @@ -1346,6 +1350,10 @@ def mk_mod(self, ex: Exec, x: Any, y: Any) -> Any:
# ex.path.append(Or(y == con(0), ULT(term, y))) # (x % y) < y if y != 0
return term

def mk_mul(self, ex: Exec, x: Any, y: Any) -> Any:
    """Build an abstract (uninterpreted) product of two bit-vector terms.

    The multiplication symbol is chosen from ``f_mul`` by the bit-width of
    ``x``, and the result is that symbol applied to ``(x, y)``. ``ex`` is
    accepted for signature parity with the other ``mk_*`` helpers but is not
    used here.

    Raises:
        KeyError: if ``x.size()`` has no entry in ``f_mul``.
    """
    abstract_mul = f_mul[x.size()]
    return abstract_mul(x, y)

def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
w1 = b2i(w1)
w2 = b2i(w2)
Expand All @@ -1357,7 +1365,38 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
return w1 - w2

if op == EVM.MUL:
return w1 * w2
is_bv_value_w1 = is_bv_value(w1)
is_bv_value_w2 = is_bv_value(w2)

if is_bv_value_w1 and is_bv_value_w2:
return w1 * w2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to double check: we're fine here in case of overflow?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, and it's indeed intended.


if is_bv_value_w1:
i1: int = w1.as_long()
if i1 == 0:
return w1

if i1 == 1:
return w2

if is_power_of_two(i1):
return w2 << (i1.bit_length() - 1)

if is_bv_value_w2:
i2: int = w2.as_long()
if i2 == 0:
return w2

if i2 == 1:
return w1

if is_power_of_two(i2):
return w1 << (i2.bit_length() - 1)

if is_bv_value_w1 or is_bv_value_w2:
return w1 * w2

return self.mk_mul(ex, w1, w2)

if op == EVM.DIV:
div_for_overflow_check = self.div_xy_y(w1, w2)
Expand All @@ -1380,7 +1419,7 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
return w1

if is_power_of_two(i2):
return LShR(w1, int(math.log(i2, 2)))
return LShR(w1, i2.bit_length() - 1)

return self.mk_div(ex, w1, w2)

Expand All @@ -1397,7 +1436,7 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
return con(0, w2.size())

if is_power_of_two(i2):
bitsize = int(math.log(i2, 2))
bitsize = i2.bit_length() - 1
return ZeroExt(w2.size() - bitsize, Extract(bitsize - 1, 0, w1))

return self.mk_mod(ex, w1, w2)
Expand Down Expand Up @@ -1449,7 +1488,7 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
if i2 <= self.options.smt_exp_by_const:
exp = w1
for _ in range(i2 - 1):
exp = exp * w1
exp = self.arith(ex, EVM.MUL, exp, w1)
return exp

return f_exp(w1, w2)
Expand Down
13 changes: 9 additions & 4 deletions tests/test_sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from halmos.sevm import (
con,
Contract,
f_mul,
f_div,
f_sdiv,
f_mod,
Expand Down Expand Up @@ -144,7 +145,7 @@ def byte_of(i, x):
[
(o(EVM.PUSH0), [], con(0)),
(o(EVM.ADD), [x, y], x + y),
(o(EVM.MUL), [x, y], x * y),
(o(EVM.MUL), [x, y], f_mul[x.size()](x, y)),
(o(EVM.SUB), [x, y], x - y),
(o(EVM.DIV), [x, y], f_div(x, y)),
(o(EVM.DIV), [con(5), con(3)], con(1)),
Expand Down Expand Up @@ -197,13 +198,17 @@ def byte_of(i, x):
(
o(EVM.MULMOD),
[x, y, con(2**3)],
ZeroExt(253, Extract(2, 0, ZeroExt(256, x) * ZeroExt(256, y))),
ZeroExt(253, Extract(2, 0, f_mul[512](ZeroExt(256, x), ZeroExt(256, y)))),
),
(
o(EVM.MULMOD),
[x, y, z],
Extract(
255, 0, f_mod[512](ZeroExt(256, x) * ZeroExt(256, y), ZeroExt(256, z))
255,
0,
f_mod[512](
f_mul[512](ZeroExt(256, x), ZeroExt(256, y)), ZeroExt(256, z)
),
),
),
(o(EVM.MULMOD), [con(10), con(10), con(8)], con(4)),
Expand All @@ -219,7 +224,7 @@ def byte_of(i, x):
(o(EVM.EXP), [x, y], f_exp(x, y)),
(o(EVM.EXP), [x, con(0)], con(1)),
(o(EVM.EXP), [x, con(1)], x),
(o(EVM.EXP), [x, con(2)], x * x),
(o(EVM.EXP), [x, con(2)], f_mul[x.size()](x, x)),
(o(EVM.SIGNEXTEND), [con(0), y], SignExt(248, Extract(7, 0, y))),
(o(EVM.SIGNEXTEND), [con(1), y], SignExt(240, Extract(15, 0, y))),
(o(EVM.SIGNEXTEND), [con(30), y], SignExt(8, Extract(247, 0, y))),
Expand Down
0