Browse Source

add prepared statement implementation

Christine Dodrill 2 years ago
parent
commit
a70b08e3d7
2 changed files with 87 additions and 0 deletions
  1. 54 0
      prepared_statement.go
  2. 33 0
      prepared_statement_test.go

+ 54 - 0
prepared_statement.go

@@ -0,0 +1,54 @@
+package gorqlite
+
+import (
+	"fmt"
+	"strings"
+)
+
+// EscapeString sql-escapes a string.
+func EscapeString(value string) string {
+	replace := [][2]string{
+		{`\`, `\\`},
+		{`\0`, `\\0`},
+		{`\n`, `\\n`},
+		{`\r`, `\\r`},
+		{`"`, `\"`},
+		{`'`, `\'`},
+	}
+
+	for _, val := range replace {
+		value = strings.Replace(value, val[0], val[1], -1)
+	}
+
+	return value
+}
+
+// PreparedStatement is a simple wrapper around fmt.Sprintf for prepared SQL
+// statements.
+type PreparedStatement struct {
+	body string
+}
+
+// NewPreparedStatement takes a sprintf syntax SQL query for later binding of
+// parameters.
+func NewPreparedStatement(body string) PreparedStatement {
+	return PreparedStatement{body: body}
+}
+
+// Bind takes arguments and SQL-escapes them, then calling fmt.Sprintf.
+func (p PreparedStatement) Bind(args ...interface{}) string {
+	var spargs []interface{}
+
+	for _, arg := range args {
+		switch arg.(type) {
+		case string:
+			spargs = append(spargs, `'`+EscapeString(arg.(string))+`'`)
+		case fmt.Stringer:
+			spargs = append(spargs, `'`+EscapeString(arg.(fmt.Stringer).String())+`'`)
+		default:
+			spargs = append(spargs, arg)
+		}
+	}
+
+	return fmt.Sprintf(p.body, spargs...)
+}

+ 33 - 0
prepared_statement_test.go

@@ -0,0 +1,33 @@
+package gorqlite
+
+import "testing"
+
+func TestPreparedStatement(t *testing.T) {
+	cases := []struct {
+		input  string
+		args   []interface{}
+		output string
+	}{
+		{
+			input:  "SELECT * FROM posts WHERE creator=%d",
+			args:   []interface{}{42},
+			output: "SELECT * FROM posts WHERE creator=42",
+		},
+		{
+			input:  "INSERT INTO posts(body) VALUES(%s)",
+			args:   []interface{}{`foo "bar" baz`},
+			output: `INSERT INTO posts(body) VALUES('foo \"bar\" baz')`,
+		},
+	}
+
+	for _, cs := range cases {
+		t.Run(cs.input, func(t *testing.T) {
+			p := NewPreparedStatement(cs.input)
+			outp := p.Bind(cs.args...)
+
+			if outp != cs.output {
+				t.Fatalf("expected output to be %s but got: %s", cs.output, outp)
+			}
+		})
+	}
+}