homework-jianmu/alert/app/expr/expr.go

793 lines
15 KiB
Go

package expr
import (
"errors"
"io"
"math"
"strconv"
"strings"
"text/scanner"
)
var (
// compile errors
ErrorExpressionSyntax = errors.New("expression syntax error")
ErrorUnrecognizedFunction = errors.New("unrecognized function")
ErrorArgumentCount = errors.New("too many/few arguments")
ErrorInvalidFloat = errors.New("invalid float")
ErrorInvalidInteger = errors.New("invalid integer")
// eval errors
ErrorUnsupportedDataType = errors.New("unsupported data type")
ErrorInvalidOperationFloat = errors.New("invalid operation for float")
ErrorInvalidOperationInteger = errors.New("invalid operation for integer")
ErrorInvalidOperationBoolean = errors.New("invalid operation for boolean")
ErrorOnlyIntegerAllowed = errors.New("only integers is allowed")
ErrorDataTypeMismatch = errors.New("data type mismatch")
)
// binary operator precedence
// 5 * / % << >> &
// 4 + - | ^
// 3 == != < <= > >=
// 2 &&
// 1 ||
const (
opOr = -(iota + 1000) // ||
opAnd // &&
opEqual // ==
opNotEqual // !=
opGTE // >=
opLTE // <=
opLeftShift // <<
opRightShift // >>
)
type lexer struct {
scan scanner.Scanner
tok rune
}
func (l *lexer) init(src io.Reader) {
l.scan.Error = func(s *scanner.Scanner, msg string) {
panic(errors.New(msg))
}
l.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings
l.scan.Init(src)
l.tok = l.next()
}
func (l *lexer) next() rune {
l.tok = l.scan.Scan()
switch l.tok {
case '|':
if l.scan.Peek() == '|' {
l.tok = opOr
l.scan.Scan()
}
case '&':
if l.scan.Peek() == '&' {
l.tok = opAnd
l.scan.Scan()
}
case '=':
if l.scan.Peek() == '=' {
l.tok = opEqual
l.scan.Scan()
} else {
// TODO: error
}
case '!':
if l.scan.Peek() == '=' {
l.tok = opNotEqual
l.scan.Scan()
} else {
// TODO: error
}
case '<':
if tok := l.scan.Peek(); tok == '<' {
l.tok = opLeftShift
l.scan.Scan()
} else if tok == '=' {
l.tok = opLTE
l.scan.Scan()
}
case '>':
if tok := l.scan.Peek(); tok == '>' {
l.tok = opRightShift
l.scan.Scan()
} else if tok == '=' {
l.tok = opGTE
l.scan.Scan()
}
}
return l.tok
}
func (l *lexer) token() rune {
return l.tok
}
func (l *lexer) text() string {
switch l.tok {
case opOr:
return "||"
case opAnd:
return "&&"
case opEqual:
return "=="
case opNotEqual:
return "!="
case opLeftShift:
return "<<"
case opLTE:
return "<="
case opRightShift:
return ">>"
case opGTE:
return ">="
default:
return l.scan.TokenText()
}
}
type Expr interface {
Eval(env func(string) interface{}) interface{}
}
type unaryExpr struct {
op rune
subExpr Expr
}
func (ue *unaryExpr) Eval(env func(string) interface{}) interface{} {
val := ue.subExpr.Eval(env)
switch v := val.(type) {
case float64:
if ue.op != '-' {
panic(ErrorInvalidOperationFloat)
}
return -v
case int64:
switch ue.op {
case '-':
return -v
case '~':
return ^v
default:
panic(ErrorInvalidOperationInteger)
}
case bool:
if ue.op != '!' {
panic(ErrorInvalidOperationBoolean)
}
return !v
default:
panic(ErrorUnsupportedDataType)
}
}
type binaryExpr struct {
op rune
lhs Expr
rhs Expr
}
func (be *binaryExpr) Eval(env func(string) interface{}) interface{} {
lval := be.lhs.Eval(env)
rval := be.rhs.Eval(env)
switch be.op {
case '*':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv * rv
case int64:
return lv * float64(rv)
case bool:
panic(ErrorInvalidOperationBoolean)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) * rv
case int64:
return lv * rv
case bool:
panic(ErrorInvalidOperationBoolean)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case '/':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
if rv == 0 {
return math.Inf(int(lv))
} else {
return lv / rv
}
case int64:
if rv == 0 {
return math.Inf(int(lv))
} else {
return lv / float64(rv)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case int64:
switch rv := rval.(type) {
case float64:
if rv == 0 {
return math.Inf(int(lv))
} else {
return float64(lv) / rv
}
case int64:
if rv == 0 {
return math.Inf(int(lv))
} else {
return lv / rv
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case '%':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return math.Mod(lv, rv)
case int64:
return math.Mod(lv, float64(rv))
case bool:
panic(ErrorInvalidOperationBoolean)
}
case int64:
switch rv := rval.(type) {
case float64:
return math.Mod(float64(lv), rv)
case int64:
if rv == 0 {
return math.Inf(int(lv))
} else {
return lv % rv
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case opLeftShift:
switch lv := lval.(type) {
case int64:
switch rv := rval.(type) {
case int64:
return lv << rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case opRightShift:
switch lv := lval.(type) {
case int64:
switch rv := rval.(type) {
case int64:
return lv >> rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case '&':
switch lv := lval.(type) {
case int64:
switch rv := rval.(type) {
case int64:
return lv & rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case '+':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv + rv
case int64:
return lv + float64(rv)
case bool:
panic(ErrorInvalidOperationBoolean)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) + rv
case int64:
return lv + rv
case bool:
panic(ErrorInvalidOperationBoolean)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case '-':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv - rv
case int64:
return lv - float64(rv)
case bool:
panic(ErrorInvalidOperationBoolean)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) - rv
case int64:
return lv - rv
case bool:
panic(ErrorInvalidOperationBoolean)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case '|':
switch lv := lval.(type) {
case int64:
switch rv := rval.(type) {
case int64:
return lv | rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case '^':
switch lv := lval.(type) {
case int64:
switch rv := rval.(type) {
case int64:
return lv ^ rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case opEqual:
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv == rv
case int64:
return lv == float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) == rv
case int64:
return lv == rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
switch rv := rval.(type) {
case float64:
case int64:
case bool:
return lv == rv
}
}
case opNotEqual:
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv != rv
case int64:
return lv != float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) != rv
case int64:
return lv != rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
switch rv := rval.(type) {
case float64:
case int64:
case bool:
return lv != rv
}
}
case '<':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv < rv
case int64:
return lv < float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) < rv
case int64:
return lv < rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case opLTE:
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv <= rv
case int64:
return lv <= float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) <= rv
case int64:
return lv <= rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case '>':
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv > rv
case int64:
return lv > float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) > rv
case int64:
return lv > rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case opGTE:
switch lv := lval.(type) {
case float64:
switch rv := rval.(type) {
case float64:
return lv >= rv
case int64:
return lv >= float64(rv)
case bool:
panic(ErrorDataTypeMismatch)
}
case int64:
switch rv := rval.(type) {
case float64:
return float64(lv) >= rv
case int64:
return lv >= rv
case bool:
panic(ErrorDataTypeMismatch)
}
case bool:
panic(ErrorInvalidOperationBoolean)
}
case opAnd:
switch lv := lval.(type) {
case bool:
switch rv := rval.(type) {
case bool:
return lv && rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
case opOr:
switch lv := lval.(type) {
case bool:
switch rv := rval.(type) {
case bool:
return lv || rv
default:
panic(ErrorOnlyIntegerAllowed)
}
default:
panic(ErrorOnlyIntegerAllowed)
}
}
return nil
}
type funcExpr struct {
name string
args []Expr
}
func (fe *funcExpr) Eval(env func(string) interface{}) interface{} {
argv := make([]interface{}, 0, len(fe.args))
for _, arg := range fe.args {
argv = append(argv, arg.Eval(env))
}
return funcs[fe.name].call(argv)
}
type floatExpr struct {
val float64
}
func (fe *floatExpr) Eval(env func(string) interface{}) interface{} {
return fe.val
}
type intExpr struct {
val int64
}
func (ie *intExpr) Eval(env func(string) interface{}) interface{} {
return ie.val
}
type boolExpr struct {
val bool
}
func (be *boolExpr) Eval(env func(string) interface{}) interface{} {
return be.val
}
type stringExpr struct {
val string
}
func (se *stringExpr) Eval(env func(string) interface{}) interface{} {
return se.val
}
type varExpr struct {
name string
}
func (ve *varExpr) Eval(env func(string) interface{}) interface{} {
return env(ve.name)
}
func Compile(src string) (expr Expr, err error) {
defer func() {
switch x := recover().(type) {
case nil:
case error:
err = x
default:
}
}()
lexer := lexer{}
lexer.init(strings.NewReader(src))
expr = parseBinary(&lexer, 0)
if lexer.token() != scanner.EOF {
panic(ErrorExpressionSyntax)
}
return expr, nil
}
func precedence(op rune) int {
switch op {
case opOr:
return 1
case opAnd:
return 2
case opEqual, opNotEqual, '<', '>', opGTE, opLTE:
return 3
case '+', '-', '|', '^':
return 4
case '*', '/', '%', opLeftShift, opRightShift, '&':
return 5
}
return 0
}
// binary = unary ('+' binary)*
func parseBinary(lexer *lexer, lastPrec int) Expr {
lhs := parseUnary(lexer)
for {
op := lexer.token()
prec := precedence(op)
if prec <= lastPrec {
break
}
lexer.next() // consume operator
rhs := parseBinary(lexer, prec)
lhs = &binaryExpr{op: op, lhs: lhs, rhs: rhs}
}
return lhs
}
// unary = '+|-' expr | primary
func parseUnary(lexer *lexer) Expr {
flag := false
for tok := lexer.token(); ; tok = lexer.next() {
if tok == '-' {
flag = !flag
} else if tok != '+' {
break
}
}
if flag {
return &unaryExpr{op: '-', subExpr: parsePrimary(lexer)}
}
flag = false
for tok := lexer.token(); tok == '!'; tok = lexer.next() {
flag = !flag
}
if flag {
return &unaryExpr{op: '!', subExpr: parsePrimary(lexer)}
}
flag = false
for tok := lexer.token(); tok == '~'; tok = lexer.next() {
flag = !flag
}
if flag {
return &unaryExpr{op: '~', subExpr: parsePrimary(lexer)}
}
return parsePrimary(lexer)
}
// primary = id
// | id '(' expr ',' ... ',' expr ')'
// | num
// | '(' expr ')'
func parsePrimary(lexer *lexer) Expr {
switch lexer.token() {
case '+', '-', '!', '~':
return parseUnary(lexer)
case '(':
lexer.next() // consume '('
node := parseBinary(lexer, 0)
if lexer.token() != ')' {
panic(ErrorExpressionSyntax)
}
lexer.next() // consume ')'
return node
case scanner.Ident:
id := strings.ToLower(lexer.text())
if lexer.next() != '(' {
if id == "true" {
return &boolExpr{val: true}
} else if id == "false" {
return &boolExpr{val: false}
} else {
return &varExpr{name: id}
}
}
node := funcExpr{name: id}
for lexer.next() != ')' {
arg := parseBinary(lexer, 0)
node.args = append(node.args, arg)
if lexer.token() != ',' {
break
}
}
if lexer.token() != ')' {
panic(ErrorExpressionSyntax)
}
if fn, ok := funcs[id]; !ok {
panic(ErrorUnrecognizedFunction)
} else if fn.minArgs >= 0 && len(node.args) < fn.minArgs {
panic(ErrorArgumentCount)
} else if fn.maxArgs >= 0 && len(node.args) > fn.maxArgs {
panic(ErrorArgumentCount)
}
lexer.next() // consume it
return &node
case scanner.Int:
val, e := strconv.ParseInt(lexer.text(), 0, 64)
if e != nil {
panic(ErrorInvalidFloat)
}
lexer.next()
return &intExpr{val: val}
case scanner.Float:
val, e := strconv.ParseFloat(lexer.text(), 0)
if e != nil {
panic(ErrorInvalidInteger)
}
lexer.next()
return &floatExpr{val: val}
case scanner.String:
panic(errors.New("strings are not allowed in expression at present"))
val := lexer.text()
lexer.next()
return &stringExpr{val: val}
default:
panic(ErrorExpressionSyntax)
}
}