diff --git a/searx/plugins/calculator/__init__.py b/searx/plugins/calculator/__init__.py
index aeabc9d9f..306a9d554 100644
--- a/searx/plugins/calculator/__init__.py
+++ b/searx/plugins/calculator/__init__.py
@@ -2,6 +2,7 @@
 """Calculate mathematical expressions using ack#eval
 """
 
+import decimal
 import re
 import sys
 import subprocess
@@ -87,9 +88,9 @@ def post_search(_request, search):
         val = babel.numbers.parse_decimal(val, ui_locale, numbering_system="latn")
         return str(val)
 
-    decimal = ui_locale.number_symbols["latn"]["decimal"]
-    group = ui_locale.number_symbols["latn"]["group"]
-    query = re.sub(f"[0-9]+[{decimal}|{group}][0-9]+[{decimal}|{group}]?[0-9]?", _decimal, query)
+    loc_decimal = ui_locale.number_symbols["latn"]["decimal"]
+    loc_group = ui_locale.number_symbols["latn"]["group"]
+    query = re.sub(f"[0-9]+[{loc_decimal}|{loc_group}][0-9]+[{loc_decimal}|{loc_group}]?[0-9]?", _decimal, query)
 
     # only numbers and math operators are accepted
     if any(str.isalpha(c) for c in query):
@@ -102,6 +103,10 @@ def post_search(_request, search):
     result = call_calculator(query_py_formatted, 0.05)
     if result is None or result == "":
         return True
-    result = babel.numbers.format_decimal(result, locale=ui_locale)
+    if len(result) < 15:  # arbitrary number, TODO : check the actual limit
+        try:
+            result = babel.numbers.format_decimal(result, locale=ui_locale)
+        except decimal.InvalidOperation:
+            pass  # keep the unformatted result
     search.result_container.answers['calculate'] = {'answer': f"{search.search_query.query} = {result}"}
     return True
diff --git a/searx/plugins/calculator/calculator_process.py b/searx/plugins/calculator/calculator_process.py
index 69d3686c7..2222c12d7 100644
--- a/searx/plugins/calculator/calculator_process.py
+++ b/searx/plugins/calculator/calculator_process.py
@@ -1,44 +1,151 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
+# pylint: disable=C0301, C0103
 """Standalone script to actually calculate mathematical expressions using ast
 
 This is not a module, the SearXNG modules are not available here
+
+Use Decimal instead of float to keep precision
 """
 
 import ast
 import sys
 import operator
+from decimal import Decimal, DecimalException
 from typing import Callable
 
 
+def _can_be_int(a: Decimal) -> bool:
+    """Return True when -1E10 < a < 1E10 (bound checked before int() / `% 1` on Decimals)."""
+    return -1E10 < a < 1E10
+
+
+def _div(a: int | Decimal, b: int | Decimal) -> int | Decimal:
+    # If exactly divisible, return int
+    if isinstance(a, int) and isinstance(b, int) and a % b == 0:
+        return a // b
+
+    # Otherwise, make sure to use Decimal and divide
+    result = Decimal(a) / Decimal(b)
+
+    # Convert integral Decimal back to int
+    if _can_be_int(result) and (result % 1) == 0:
+        return int(result)
+
+    # Non-integral quotient: keep the Decimal
+    return result
+
+
+def _pow(a: int | Decimal, b: int | Decimal) -> int | Decimal:
+    # int ** negative int yields a float (2 ** -1 == 0.5) and Decimal refuses to mix with
+    # floats (Decimal + float raises TypeError), so compute that case as a Decimal
+    if isinstance(a, int) and isinstance(b, int) and b < 0:
+        return Decimal(a) ** b
+    return a ** b
+
+
+def _compare(ops: list[ast.cmpop], values: list[int | Decimal]) -> int:
+    """Return 1 if every comparison of the chain holds, else 0.
+
+    2 < 3 becomes ops=[ast.Lt] and values=[2,3]
+    2 < 3 <= 4 becomes ops=[ast.Lt, ast.LtE] and values=[2,3, 4]
+    """
+    for op, a, b in zip(ops, values, values[1:]):
+        if isinstance(op, ast.Eq) and a == b:
+            continue
+        if isinstance(op, ast.NotEq) and a != b:
+            continue
+        if isinstance(op, ast.Lt) and a < b:
+            continue
+        if isinstance(op, ast.LtE) and a <= b:
+            continue
+        if isinstance(op, ast.Gt) and a > b:
+            continue
+        if isinstance(op, ast.GtE) and a >= b:
+            continue
+
+        # Ignore impossible ops:
+        # * ast.Is
+        # * ast.IsNot
+        # * ast.In
+        # * ast.NotIn
+
+        # the result is False for a and b and operation op
+        return 0
+    # the results for all the ops are True
+    return 1
+
+
 operators: dict[type, Callable] = {
     ast.Add: operator.add,
     ast.Sub: operator.sub,
     ast.Mult: operator.mul,
-    ast.Div: operator.truediv,
+    ast.Div: _div,
+    ast.FloorDiv: operator.floordiv,
-    ast.Pow: operator.pow,
+    ast.Pow: _pow,
     ast.BitXor: operator.xor,
+    ast.BitOr: operator.or_,
+    ast.BitAnd: operator.and_,
     ast.USub: operator.neg,
+    ast.RShift: operator.rshift,
+    ast.LShift: operator.lshift,
+    ast.Mod: operator.mod,
+    ast.Compare: _compare,
 }
 
 
 def _eval_expr(expr):
     """
+    Evaluate *expr* with int/Decimal arithmetic; return "" when the result is undefined.
+
-    >>> _eval_expr('2^6')
+    >>> _eval_expr('2^6')  # xor here; unreachable from the plugin, which replaces ^ by **
     4
     >>> _eval_expr('2**6')
     64
-    >>> _eval_expr('1 + 2*3**(4^5) / (6 + -7)')
-    -5.0
+    >>> _eval_expr('1 + 2*3**(4 & 5) / (6 + -7)')
+    -161
+    >>> _eval_expr('1 + 2*3**(4**5) / 3') == 1 + 2 * 3**1023
+    True
+    >>> _eval_expr('1 + 2*3**(4**5) // 3**3') == 1 + 2 * 3**1021
+    True
+    >>> _eval_expr('1 + 2*3**(4**5) >> 1620')
+    16
+    >>> _eval_expr('2 < 3 < 5')
+    1
+    >>> _eval_expr('2 > 3')
+    0
+    >>> _eval_expr('5/3 + 6/3 - 5/3')
+    2
+    >>> _eval_expr('0.1 + 0.1 + 0.1 - 0.3')
+    0
+    >>> _eval_expr('1/3*3')
+    1
+    >>> _eval_expr('1000.5 + 1')
+    Decimal('1001.5')
+    >>> _eval_expr('0.5 + 2**-1')
+    1
+    >>> _eval_expr('(-8) ** 0.5')
+    ''
     """
     try:
-        return _eval(ast.parse(expr, mode='eval').body)
+        result = _eval(ast.parse(expr, mode='eval').body)
+        if isinstance(result, Decimal) and _can_be_int(result) and abs(result - result.to_integral_value()) < Decimal("1E-25"):
+            # snap near-integral results to int (0 instead of 0.0, 1 instead of 0.999...9); round(result, 25)
+            # is not used: quantize() raises InvalidOperation for |result| >= 1000 in the default 28-digit context
+            result = int(result.to_integral_value())
+        return result
     except ZeroDivisionError:
         # This is undefined
         return ""
+    except (OverflowError, DecimalException):
+        # Also undefined: Decimal overflow, invalid operations such as (-8) ** 0.5
+        return ""
 
 
 def _eval(node):
-    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
+    if isinstance(node, ast.Constant) and isinstance(node.value, float):
+        return Decimal(str(node.value))
+
+    if isinstance(node, ast.Constant) and isinstance(node.value, int):
         return node.value
 
     if isinstance(node, ast.BinOp):
@@ -47,6 +154,9 @@ def _eval(node):
     if isinstance(node, ast.UnaryOp):
         return operators[type(node.op)](_eval(node.operand))
 
+    if isinstance(node, ast.Compare):
+        return _compare(node.ops, [_eval(node.left)] + [_eval(c) for c in node.comparators])
+
     raise TypeError(node)
 
 