A visual explanation of the PGM Index

A few months back, an HN post about learned indexes caught my attention. 83x less space with the same performance as a B-tree? What sort of magic trick is this?! And why isn’t everybody else using it if it’s so good?

I decided to spend some time reading the paper to really understand what’s going on. Now, reading a scientific paper is a daunting endeavour: most of it is written in small text, decorated with magical Greek symbols, with little to no diagrams for us lay people. But this is a genuinely beautiful data structure which deserves to be used more widely in other systems, rather than just languish in academic circles.

So if you have been putting off reading it, this post is my attempt at a simplified explanation of what is in the paper. No mathematical symbols. No equations. Just the basic idea of how it works.

What is an index

Let’s consider a sorted array of integers. We want to calculate the predecessor of x; i.e. given x, find the largest integer in the array less than or equal to x. For example, if our array is {2,8,10,18,20}, predecessor(15) would be 10, and predecessor(20) would be 20. An extension of this is the range(x, y) function, which gives us the set of integers lying within the range of x and y. Any data structure satisfying these requirements can essentially be considered an index.
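
To make the predecessor operation concrete, here is a minimal Go sketch (my own, not from the paper) of predecessor over a plain sorted slice using binary search. This is the baseline that any index is trying to beat in space and/or lookup cost:

package main

import (
	"fmt"
	"sort"
)

// predecessor returns the largest element of keys that is <= x.
// keys must be sorted in ascending order; ok is false if every
// element is greater than x.
func predecessor(keys []int, x int) (p int, ok bool) {
	// sort.Search returns the smallest index i for which keys[i] > x.
	i := sort.Search(len(keys), func(i int) bool { return keys[i] > x })
	if i == 0 {
		return 0, false
	}
	return keys[i-1], true
}

func main() {
	keys := []int{2, 8, 10, 18, 20}
	fmt.Println(predecessor(keys, 15)) // 10 true
	fmt.Println(predecessor(keys, 20)) // 20 true
}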

If we go through existing categories of data structures:

  • Hash-based ones can only be used to “lookup” the position of a key.
  • Bitmap-based indexes can be expensive to store, maintain and decompress.
  • Trie-based indexes are mostly pointer based, and therefore take up space proportional to the data set.

Which brings us to B-tree based indexes and their variants: the go-to choice for such operations, and widely used in practically all databases.

What is a learned index

In a learned index, the key idea is that indexes can be trained to “learn” this predecessor function. Naturally, the first thing that comes to mind is Machine Learning. And indeed, some implementations have used ML to learn this mapping of key to array position within an error approximation.

But unlike those, the PGM index is a fully dynamic learned index that involves no machine learning, provides a maximum error tolerance, and takes up less space.

PGM Index

PGM stands for “piece-wise geometric model”. It attempts to create line segments that fit the key distribution in a Cartesian plane. We call this a linear approximation. And it’s called “piece-wise” because a single line segment may not be able to express the entire set of keys within a given error margin. Essentially, it’s a “piece-wise linear approximation” model. If all that sounds complicated, it really isn’t. Let’s take an example.

Consider the input data of {2,8,10,18,20} that we had earlier. Let’s plot these points on a graph with the keys on the x-axis and the array positions on the y-axis.

basic graph

Now, we can see that a set of points which is more or less linear can be expressed as a single line segment. Let’s draw a line through the points (2,0), (8,1), (10,2).

line

So any value lying within [2,10] can be mapped with this line. For example, let’s try to find predecessor(9). We take 9 as the value of x and compute the corresponding y from the line.

point

Once we get y, the algorithm guarantees that the actual position will lie within the range [y-e, y+e]. So if we just do a binary search within this window of 2e + 1 positions, we get our desired position.

That’s all there is to it. Instead of storing the keys, we just store the slopes and intercepts, whose number is completely unrelated to the size of the data and depends instead on its shape. The more random the data, the more line segments we need to express it. And at the other extreme, a set like {1,2,3,4} can be expressed with a single line with zero error.
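
To make that concrete, here is a rough Go sketch (my own illustration, not code from the paper or its library) of what a single segment looks like and how a lookup inside it works; eps plays the role of e:

// segment is one line of the piece-wise approximation. It covers a
// contiguous range of keys starting at key.
type segment struct {
	key       int     // first key covered by this segment
	slope     float64 // the line fitted over (key, position) points
	intercept float64
}

// approxPos returns the predicted position of k, clamped to [0, n).
// The construction guarantees the true position is within eps of this.
func (s segment) approxPos(k, n int) int {
	p := int(s.slope*float64(k) + s.intercept)
	if p < 0 {
		return 0
	}
	if p >= n {
		return n - 1
	}
	return p
}

// searchIn finds the predecessor position of k in keys by scanning only
// the window [pos-eps, pos+eps] around the predicted position.
func searchIn(keys []int, s segment, k, eps int) int {
	pos := s.approxPos(k, len(keys))
	lo, hi := pos-eps, pos+eps
	if lo < 0 {
		lo = 0
	}
	if hi > len(keys)-1 {
		hi = len(keys) - 1
	}
	// Binary search for the rightmost keys[i] <= k within the window.
	for lo < hi {
		mid := (lo + hi + 1) / 2
		if keys[mid] <= k {
			lo = mid
		} else {
			hi = mid - 1
		}
	}
	return lo
}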

But this leads to another problem. Once we have a set of line segments, each line segment only covers a portion of our entire key space. Given any random key, how do we know which line segment to use? It’s simple. We run the same algorithm again!

  • Construction

Let’s run through an example and see how we build up the entire index. Assume our data set is {2, 12, 15, 18, 23, 24, 29, 31, 34, 36, 38, 48}, and our error bound is e.

The algorithm to construct the index is as follows:

  1. We take each point from our set of {k, pos(k)}, and incrementally construct a convex hull from those points.
  2. At every iteration, we construct a bounding rectangle from the convex hull. This is a well-known computational geometry problem for which several solutions exist. One of them is the Rotating Calipers algorithm.
  3. As long as the height of the bounding rectangle is not more than 2e, we keep adding points to the hull.
  4. When the height exceeds 2e, we stop our process, and construct a line segment joining the midpoints of the two arms of the rectangle.
  5. We store the first point in the set, along with the slope and intercept of the line segment, and repeat the whole process with the remaining points.

At the end, we will get an array of tuples of (point, slope, intercept).

demo first_pass

Now let’s wipe all the remaining points except the ones from the tuples and run the same algorithm again.

demo second_pass

We see that each time, we get an array of decreasing size until we just have a root element. The in-memory representation becomes something like this:

In-memory representation

  • Search

The algorithm to search for a value is as follows:

  1. We start with the root tuple of the index and compute y = k * sl + ic for an input value of k, where sl and ic are the tuple’s slope and intercept.
  2. A lower bound lo is calculated to be y-e and a similar upper bound hi as y+e.
  3. We search the next level’s array within A[lo:hi] to find the rightmost element such that A[i].key <= k.
  4. Once an element is found, the whole process is repeated: we compute y from that tuple’s slope and intercept, and search in the next array.
  5. This continues until we reach the original array, and we find our target position.

Search path
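
Putting the levels together, here is a rough Go sketch of the full search, reusing the segment type and searchIn helper from the earlier sketch (again my own simplification, not the paper’s pseudocode). levels[0] holds the single root tuple, each lower level’s segments are ordered by their first key, and keys is the original sorted array:

type pgm struct {
	levels [][]segment // levels[0] contains just the root segment
	keys   []int       // the original sorted keys
	eps    int         // the error bound e
}

// predecessorIndex returns the position of the rightmost key <= k.
func (p *pgm) predecessorIndex(k int) int {
	seg := p.levels[0][0] // start at the root tuple
	for _, level := range p.levels[1:] {
		// Predict where k falls in this level's segment array and search
		// only the 2e+1 window around it for the rightmost segment whose
		// first key is <= k.
		pos := seg.approxPos(k, len(level))
		lo, hi := pos-p.eps, pos+p.eps
		if lo < 0 {
			lo = 0
		}
		if hi > len(level)-1 {
			hi = len(level) - 1
		}
		for lo < hi {
			mid := (lo + hi + 1) / 2
			if level[mid].key <= k {
				lo = mid
			} else {
				hi = mid - 1
			}
		}
		seg = level[lo]
	}
	// Final bounded search over the actual data array.
	return searchIn(p.keys, seg, k, p.eps)
}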

The paper proves that the number of tuples (m) at the last level will always be less than n/2e. Since this also holds for the upper levels, a PGM index cannot be worse than a 2e-way B-tree: if at every level we do a binary search within 2e + 1 positions, our worst-case time complexity is O(log(m) + log(e)). In practice, however, a PGM index is seen to be much faster than a B-tree because m is usually far lower than n.

  • Addition/Removal

Insertions and deletions in a PGM index are slightly trickier compared to traditional indexes. That is because a single tuple can index a variable and potentially large subset of the data, which makes the classic B-tree node split and merge algorithms inapplicable. The paper proposes two approaches to handle updates: one customized for append-only workloads like time-series data, and another for general random-update scenarios.

In an append-only scenario, the key is first added to the last tuple. If this does not exceed the error threshold e, the process stops. If it does, we create a new tuple with the key, and continue the process with the last tuple of the upper layer. This continues until we find a layer where adding the key stays within the threshold. If this propagates all the way to the root node, the root gets split into two nodes, and a new root gets created above them.
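
Here is a hedged sketch of that propagation, reusing the pgm and segment types from the search sketch above. canAbsorb and newSegment are hypothetical helpers standing in for the convex-hull check and segment creation, and appending the key to the underlying data array is omitted:

// appendKey adds k, which must be larger than every existing key.
func (p *pgm) appendKey(k int) {
	// Walk from the bottom level (the one over the data) upwards.
	for lvl := len(p.levels) - 1; lvl >= 0; lvl-- {
		last := len(p.levels[lvl]) - 1
		// canAbsorb would try to extend the last segment with k and report
		// whether it still stays within the error bound.
		if canAbsorb(&p.levels[lvl][last], k) {
			return
		}
		// Otherwise start a fresh segment for k on this level, and let the
		// next iteration try to register its first key one level up.
		p.levels[lvl] = append(p.levels[lvl], newSegment(k))
	}
	// Even the root could not absorb the key: create a new root level.
	p.levels = append([][]segment{{newSegment(k)}}, p.levels...)
}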

For inserts that happen at arbitrary positions, it gets slightly more complicated. In this case, we have to maintain multiple PGM indexes built over sets of keys. These sets are either empty or have sizes 2^0, 2^1, ..., 2^b, where b = O(log(n)). Each insert of a key k finds the first empty set, builds a new PGM index from k plus all the previous sets, and then empties those previous sets (a small sketch of this scheme follows the example below). Let’s take an example. Assume we are starting from scratch and we want to insert 3, 8, 1, 6, 9 into the index.

  1. Since everything is empty, we find our first set S0 and insert 3. So our PGM looks like

     S0 = [3]
    
  2. Now the next empty set is S1, because S0 is non-empty. So we take the 3 from S0, merge it with 8, and put both into S1. S0 is emptied.

     S0 = []
     S1 = [3,8]
    
  3. Our next key is 1. The first empty set is S0. We just add 1 to S0 and move on.

     S0 = [1]
     S1 = [3,8]
    
  4. Our next key is 6. Both S0 and S1 are non-empty now, so the first empty set is S2. We merge 6 with the contents of S0 and S1 into S2, and empty S0 and S1.

     S0 = []
     S1 = []
     S2 = [1,3,6,8]
    
  5. Again, the first empty set is S0. So 9 goes in it.

     S0 = [9]
     S1 = []
     S2 = [1,3,6,8]
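
Here is a small Go sketch of that set-merging bookkeeping (my own simplification: a real implementation would also rebuild a PGM index over each newly filled set instead of keeping only the sorted keys):

// dynamicPGM maintains O(log(n)) sets; sets[i] is either empty or
// holds exactly 2^i keys.
type dynamicPGM struct {
	sets [][]int
}

// insert adds k by merging it with every set before the first empty slot.
func (d *dynamicPGM) insert(k int) {
	merged := []int{k}
	i := 0
	for ; i < len(d.sets); i++ {
		if len(d.sets[i]) == 0 {
			break
		}
		merged = mergeSorted(merged, d.sets[i])
		d.sets[i] = nil // empty the set
	}
	if i == len(d.sets) {
		d.sets = append(d.sets, nil)
	}
	d.sets[i] = merged // a real implementation would rebuild the PGM for this set here
}

// mergeSorted merges two sorted slices into one sorted slice.
func mergeSorted(a, b []int) []int {
	out := make([]int, 0, len(a)+len(b))
	for len(a) > 0 && len(b) > 0 {
		if a[0] <= b[0] {
			out = append(out, a[0])
			a = a[1:]
		} else {
			out = append(out, b[0])
			b = b[1:]
		}
	}
	out = append(out, a...)
	return append(out, b...)
}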
    

The deletion of a key d is handled similarly to an insert, by adding a special tombstone value that indicates the logical removal of d.

Conclusion

And that was a very basic overview of the PGM index. There are further variants, described in detail in the paper. One is the Compressed PGM index, which compresses the tuples. Then there is the Distribution-aware PGM index, which adapts itself not only to the key distribution, but also to the distribution of queries; this is desirable when it’s important for more frequent queries to respond faster than rare ones. Finally, there is the Multi-criteria PGM index, which can be tuned to optimize either for time or for space.

I have also created a port of the algorithm in Go here to understand the algorithm better. It’s just a prototype, and suffers from minor approximation issues. For a production-ready library, refer to the author’s C++ implementation here.

Lastly, I would like to thank Giorgio for taking the time to explain some aspects of the paper which I found hard to follow. His guidance has been an indispensable part of my understanding of the paper.


Setting up gopls with Sublime Text

If you are a Sublime Text user looking to set up gopls integration, you have arrived at the right place. The primary documentation for gopls assumes you are using VSCode, and most of the rest assumes GoLand, which leaves us Sublime Text users in a tight spot. This post attempts to fill that gap.

The official documentation here just mentions how to install gopls, which is barely enough. But for the sake of completeness, I will go through the entire set of steps.

Installation

  1. Install gopls on your machine.
    • Go to any temp directory and run go get golang.org/x/tools/gopls@latest.
    • If you see the error go: cannot use path@version syntax in GOPATH mode, then run GO111MODULE=on go get golang.org/x/tools/gopls@latest
    • Check that the gopls binary got installed by running which gopls.
  2. Open the Command Palette (Shift+Ctrl+P). Select “Install Package”.
  3. Select “LSP”.
  4. Open the Command Palette again.
  5. Select “LSP: Enable Language Server Globally”.
  6. Select “gopls”.

This completes the installation part, which is half the battle. Next up, we need to configure gopls.

Configuration

  1. Navigate to Preferences > Package Settings > LSP > Settings. In the User settings section, paste this:

     {
         "clients":
         {
             "gopls":
             {
                 "command": ["/home/agniva/go/bin/gopls"],
                 "enabled": true,
                 "env": {
                     "PATH": "/home/agniva/go/bin:/usr/local/go/bin"
                 },
                 "scopes":["source.go"],
                 "syntaxes": [
                     "Packages/Go/Go.sublime-syntax",
                 ],
                 "settings": {
                     "gopls.usePlaceholders": true,
                     "gopls.completeUnimported": true,
                 },
                 "languageId": "go"
             }
         },
         "only_show_lsp_completions": true,
         "show_references_in_quick_panel": true,
         "log_debug": true,
         "log_stderr": true
     }
    

    Adjust the file paths accordingly.

  2. There are several things to note here. Depending on your shell settings, you may need to pass the full file path. Otherwise, you might see the error “Could not start gopls. I/O timeout.”

  3. Any custom settings need to be placed under the settings key. And the key names need to be prefixed with “gopls.”. For the full list of settings, check here.

  4. Navigate to Preferences > Package Settings > LSP > Key Bindings. You will see that a lot of commands have keymap as “UNBOUND”. Set them as per your old shortcuts.

  5. Open the Command Palette. Select “LSP: Restart Servers”.

  6. Enjoy a working setup of gopls.

Hopefully this was helpful. As always, please feel free to suggest any improvements in the comments.

Generating WebAssembly CPU Profiles in Go

Go has had WebAssembly (wasm) support for a while now, but the tooling is still in its nascent stages. It is straightforward to build a wasm module from Go code, but running tests in a browser is still cumbersome, as it requires some HTML and JS glue to work, and generating a CPU profile isn’t even possible since wasm does not have thread support (yet).

I wrote a tool wasmbrowsertest which automates the running of tests in a browser and adds the ability to take a CPU profile. The idea is to compile the test into a binary and spin up a web server to serve the required HTML and JS to run the test. Then we use the Chrome Devtools Protocol to start a headless browser and load the web page. Finally, the console logs are captured and relayed to the command line.

This takes care of running the tests. But this post is about how to generate and analyze CPU profiles in WebAssembly natively, using the Go toolchain. Before I proceed, I should clarify that the following was done in a Chromium-based browser since it needs to work with the Chrome Devtools Protocol. The footnotes section explains why Selenium wasn’t used.

The problem

The developer tools in Google Chrome can take CPU Profiles of any webpage. This allows us to get a profile while the wasm test is running in the browser. But unfortunately, this profile has its own format, and the Go toolchain works with the pprof format. To make this work natively in Go, we need to convert the profile from this devtools format to the pprof format.

What is a profile

At a very basic level, a profile is just a set of samples, where each sample contains a stack frame. The difference between various profile formats lies in how all of this is represented on disk. Let us look at how it is represented in the devtools format, and then we will go over how to convert it to the pprof format.

CDP Profile

A CDP (Chrome Devtools Protocol) profile is represented in a json format with the following top-level keys:

{
	"startTime": ..., // Start time of the profile in us
	"endTime": ..., // End time of the profile in us.
	"nodes": [{...}, {...}, ...],
	"samples": [1,2,1,1],
	"timeDeltas": [...,...], // Time interval between consecutive samples in us.
}

nodes is a list of profile nodes. A node is a single function call site containing information about the function name, line number, and the script it was called from. It also has its own unique ID, and a list of child IDs, which are the IDs of its child nodes.

samples represents the samples taken during a profile. It is a list of node IDs, where each ID points to the leaf node of a stack frame.

To represent it in a diagram:

cdp diagram

For node 12, nodes 9, 10 and 11 are its children.

From our samples array above, we have 1,2,1,1 as samples. So, in terms of a list of stack frames, it becomes

stack frames
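
For reference, the code snippets later in this post assume the decoded JSON is held in Go types roughly like these (a simplified sketch following the devtools protocol field names; the actual structs in wasmbrowsertest may differ):

// Profile mirrors the top-level keys of the devtools JSON.
type Profile struct {
	Nodes      []ProfileNode `json:"nodes"`
	StartTime  float64       `json:"startTime"`  // microseconds
	EndTime    float64       `json:"endTime"`    // microseconds
	Samples    []int64       `json:"samples"`    // leaf node IDs
	TimeDeltas []int64       `json:"timeDeltas"` // microseconds between samples
}

// ProfileNode is a single call site in the profile.
type ProfileNode struct {
	ID        int64     `json:"id"`
	CallFrame CallFrame `json:"callFrame"`
	Children  []int64   `json:"children"`
}

// CallFrame identifies the function being executed.
type CallFrame struct {
	FunctionName string `json:"functionName"`
	URL          string `json:"url"`
	LineNumber   int64  `json:"lineNumber"`
	ColumnNumber int64  `json:"columnNumber"`
}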

PProf Profile

A pprof profile is a protocol buffer which is serialized and stored on disk in a gzip-compressed format. A profile for code running natively on a machine will contain extra information such as memory address mappings. But since our Chrome profile runs inside a browser, we do not have access to such low-level details, and hence our converted profile will not have all the features of a full pprof profile.

At a high level, a pprof profile has:

type Profile struct {
	Sample            []*Sample
	Location          []*Location
	Function          []*Function

	TimeNanos     int64
	DurationNanos int64
}

type Sample struct {
	Location []*Location
}

type Location struct {
	ID       uint64
	Line     []Line
}

type Line struct {
	Function *Function
	Line     int64
}

type Function struct {
	ID         uint64
	Name       string
	Filename   string
}

Essentially, a profile contains a list of samples. Each sample contains a list of locations. Each location contains a function object along with its line number (for simplicity’s sake, we will consider each location to have a single line). Lastly, a function object just has the function name and the file name from which it was called.

pprof diagram

It is a flat representation where the hierarchy is maintained by pointers. So, to construct such a profile, we need to create it from the bottom up: first we construct the list of functions, then the locations, and then the samples.

Converting Devtools to Pprof

To quickly recap what we are trying to achieve here: we have a devtools profile in a JSON format, and we want to convert it to the pprof format, i.e. the struct mentioned above. The TimeNanos and DurationNanos fields are simple and can be directly set. To create the Function and Location slices, we just need to iterate through the nodes array. As a quick reminder: a node is a single function call site containing information about the function name, line number, and the script it was called from, along with its own unique ID.

Note that the node ID identifies the node, not the function: different nodes can have the same call frame. So we need a key that uniquely identifies a function. Let that key be FunctionName + strconv.Itoa(int(LineNumber)) + strconv.Itoa(int(ColumnNumber)) (we get these fields from the callframe object). And for every new instance of a Function, we will use a monotonically increasing uint64 as the function ID. For the location ID, we can directly use the node ID.

With that, we can build the slice of Functions, and since the callframe also contains the line number, we can create the Location slice as well.

But before we construct the Sample information, we need to create the stack frame of each sample. That information is not directly present in the profile, but we can generate it.

We have the list of children of each node. From this, we can construct the inverse relation, where we know the parent of each node. Let’s have a map from a node ID to a struct containing a pointer to that node’s Location and to its parent’s Location. Then we iterate over the nodes list again, and for each child of a node, we point the child back to the current node. This completes all the connections, where each node points to its parent.

This is a simplified code snippet which shows what is being done.

// locMeta is a wrapper around profile.Location with an extra
// pointer towards its parent node.
type locMeta struct {
	loc    *profile.Location
	parent *profile.Location
}

// We need to map the nodeID to a struct pointing to the node
// and its parent.
locMap := make(map[int64]locMeta)
// A map to uniquely identify a Function.
fnMap := make(map[string]*profile.Function)
// A monotonically increasing function ID.
// We bump this every time we see a new function.
var fnID uint64 = 1

for _, n := range prof.Nodes {
	cf := n.CallFrame
	fnKey := cf.FunctionName + strconv.Itoa(int(cf.LineNumber)) + strconv.Itoa(int(cf.ColumnNumber))
	pFn, exists := fnMap[fnKey]
	if !exists {
		// Add to Function slice.
		pFn = &profile.Function{
			ID:         fnID,
			Name:       cf.FunctionName,
			SystemName: cf.FunctionName,
			Filename:   cf.URL,
		}
		pProf.Function = append(pProf.Function, pFn)

		fnID++

		// Add it to map
		fnMap[fnKey] = pFn
	}

	// Add to Location slice.
	loc := &profile.Location{
		ID: uint64(n.ID),
		Line: []profile.Line{
			{
				Function: pFn,
				Line:     cf.LineNumber,
			},
		},
	}
	pProf.Location = append(pProf.Location, loc)

	// Populating the loc field of the locMap
	locMap[n.ID] = locMeta{loc: loc}
}

// We need to iterate once more to build the parent-child chain.
for _, n := range prof.Nodes {
	parent := locMap[n.ID]
	// Visit each child node, get the node pointer from the map,
	// and set the parent pointer to the parent node.
	for _, childID := range n.Children {
		child := locMap[childID]
		child.parent = parent.loc
		locMap[childID] = child
	}
}

Once we have that, we can just iterate over the samples array and consult our locMap to get the leaf location, and from there walk up the parent chain to build the entire call stack.
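
Here is a hedged sketch of that final step, continuing from the locMap built above. Note that a real pprof sample also needs its Value field populated (for example with the sample count or duration), which I am omitting here:

// For every sample, walk from the leaf location up through the parent
// pointers to build the call stack, leaf first, as pprof expects.
for _, nodeID := range prof.Samples {
	meta := locMap[nodeID]
	stack := []*profile.Location{meta.loc}
	parent := meta.parent
	for parent != nil {
		stack = append(stack, parent)
		parent = locMap[int64(parent.ID)].parent
	}
	pProf.Sample = append(pProf.Sample, &profile.Sample{
		Location: stack,
	})
}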

Finally, we now have our Sample, Location and Function slices, along with other minor details which I have omitted. Once the profile is written to disk, we can simply run go tool pprof sample.prof and look at the call graph or the flame graph.

Here is an example of a profile taken for the encoding/json package’s EncoderEncode benchmark.

The SVG call graph

The flame graph

Please feel free to check the github repo to see the full source code.

Footnotes

  • The initial idea was to use a Selenium API and drive any browser to run the tests. But unfortunately, geckodriver does not support capturing console logs - https://github.com/mozilla/geckodriver/issues/284. Hence the shift to the Chrome Devtools Protocol, which circumvents the need for any external driver binary and just requires a browser installed on the machine.
  • Unfortunately, all of this will be moot once WebAssembly gets thread support (which is already in an experimental phase). Nevertheless, I hope this post shed some light on how these profiles are generated!
  • A big shoutout to Alexei Filippov from the Chrome Devtools team for helping me understand some aspects of a CDP profile.

Taking the new Go error values proposal for a spin

UPDATE July 1, 2019: The proposal has changed since the blog post was written. Stack traces have been omitted. Now, only the Unwrap, Is and As functions are kept. Also the %w format verb can be used to wrap errors. More information here.

Original article follows:

There is a new error values proposal for the Go programming language which enhances the errors and fmt packages, adding the ability to wrap errors and embed stack traces, amongst other changes. The changes are now available in the master branch and undergoing the feedback process.

I wanted to give it a spin and see how it addresses some of the issues I’ve had while using errors. For posterity, I am using the master branch at go version devel +e96c4ace9c Mon Mar 18 10:50:57 2019 +0530 linux/amd64.

Stack Traces

Adding context to an error is good. But it does not add any value to the message when I need to find where the error is coming from and fix it. It does not matter if the message is error getting users: no rows found or no rows found, if I don’t know the line number of the error’s origin. And in a big codebase, it is an extremely uphill task to map the error message to the error origin. All I can do is grep for the error message and pray that the same message is not used multiple times.

Naturally, I was ecstatic to see that errors can capture stack traces now. Let’s look at an existing example which exemplifies the problem I mentioned above and then see how to add stack traces to the errors.

package main

import (
	// ...
)

func main() {
	// getting the db handle is omitted for brevity
	err := insert(db)
	if err != nil {
		log.Printf("%+v\n", err)
	}
}

func insert(db *sql.DB) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	var id int
	err = tx.QueryRow(`INSERT INTO tablename (name) VALUES ($1) RETURNING id`, "agniva").Scan(&id)
	if err != nil {
		tx.Rollback()
		return err
	}

	_, err = tx.Exec(`INSERT INTOtablename (name) VALUES ($1)`, "ayan") // This will fail. But how do we know just from the error ?
	if err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit()
}

The example is a bit contrived. But the idea here is that if any of the SQL queries fail, there is no way of knowing which one it is.

2019/03/20 12:18:40 pq: syntax error at or near "INTOtablename"

So we add some context to it -

err = tx.QueryRow(`INSERT INTO tablename (name) VALUES ($1) RETURNING id`, "agniva").Scan(&id)
if err != nil {
	tx.Rollback()
	return fmt.Errorf("insert and return: %v", err)
}

_, err = tx.Exec(`INSERT INTOtablename (name) VALUES ($1)`, "ayan")
if err != nil {
	tx.Rollback()
	return fmt.Errorf("only insert: %v", err)
}
2019/03/20 12:19:38 only insert: pq: syntax error at or near "INTOtablename"

But that’s still not enough. I will naturally forget in which file and in which function I wrote that query, leading me to grep for “only insert”. I just want that line number :tired_face:

But all that’s changing. With the new design, function, file and line information are added to all errors returned by errors.New and fmt.Errorf. And this stack information is displayed when the error is printed by “%+v”.

If the same code is executed using Go at tip:

2019/03/20 12:20:10 only insert:
    main.doDB
        /home/agniva/play/go/src/main.go:71
  - pq: syntax error at or near "INTOtablename"

But there are some catches here. Notice how we gave a : and then added a space before writing %v. That makes the returned error have the FormatError method which allows the error to be formatted cleanly. Also, the last argument must be an error for this to happen. If we remove the :, then we just get:

2019/03/20 23:28:38 only insert pq: syntax error at or near "INTOtablename":
    main.doDB
        /home/agniva/play/go/src/main.go:72

which is just the error message with the stack trace.

This feels very magical and surprising. And unsurprisingly, there has been considerable debate on this at https://github.com/golang/go/issues/29934. In the words of @rsc here -

It’s true that recognizing : %v is a bit magical. This is a good point to raise. If we were doing it from scratch, we would not do that. But an explicit goal here is to make as many existing programs automatically start working better, just like we did in the monotonic time changes. Sometimes that constrains us more than starting on a blank slate. On balance we believe that the automatic update is a big win and worth the magic.

But now that I have the line numbers, I don’t really need to add extra context. I can just write:

err = tx.QueryRow(`INSERT INTO tablename (name) VALUES ($1) RETURNING id`, "agniva").Scan(&id)
if err != nil {
	tx.Rollback()
	return fmt.Errorf(": %v", err)
}

_, err = tx.Exec(`INSERT INTOtablename (name) VALUES ($1)`, "ayan")
if err != nil {
	tx.Rollback()
	return fmt.Errorf(": %v", err)
}
2019/03/20 13:08:15 main.doDB
        /home/agniva/play/go/src/main.go:71
  - pq: syntax error at or near "INTOtablename"

Personally, I feel this is pretty clumsy, and having to write “: %v” every time is quite cumbersome. I still think that adding a new function is cleaner and much more readable. If you read errors.WithFrame(err) instead of fmt.Errorf(": %v", err), it is immediately clear what the code is trying to achieve.

With that said, the package does expose a Frame type which allows you to create your own errors with stack information. So it is quite easy to write a helper function which does the equivalent of fmt.Errorf(": %v", err).

A crude implementation can be something like:

func withFrame(err error) error {
	return errFrame{err, errors.Caller(1)}
}

type errFrame struct {
	err error
	f   errors.Frame
}

func (ef errFrame) Error() string {
	return ef.err.Error()
}

func (ef errFrame) FormatError(p errors.Printer) (next error) {
	ef.f.Format(p)
	return ef.err
}

And then just call withFrame instead of fmt.Errorf(": %v", err):

err = tx.QueryRow(`INSERT INTO tablename (name) VALUES ($1) RETURNING id`, "agniva").Scan(&id)
if err != nil {
	tx.Rollback()
	return withFrame(err)
}

_, err = tx.Exec(`INSERT INTOtablename (name) VALUES ($1)`, "ayan")
if err != nil {
	tx.Rollback()
	return withFrame(err)
}

This generates the same output as before.

Wrapping Errors

Alright, it’s great that we are finally able to capture stack traces. But there is more to the proposal than just that. We also have the ability now to embed an error inside another error without losing any of the type information of the original error.

For example, earlier we used fmt.Errorf(": %v", err) to capture the line number. But in doing so we lost the information that err was of type pq.Error, or that it could even have been sql.ErrNoRows, which the caller could have checked and acted upon. To be able to wrap the error, we need to use the new formatting verb w. Here is what it looks like:

err = tx.QueryRow(`INSERT INTO tablename (name) VALUES ($1) RETURNING id`, "agniva").Scan(&id)
if err != nil {
	tx.Rollback()
	return fmt.Errorf(": %w", err)
}

_, err = tx.Exec(`INSERT INTOtablename (name) VALUES ($1)`, "ayan")
if err != nil {
	tx.Rollback()
	return fmt.Errorf(": %w", err)
}

Now, the position information is captured, and the original error is wrapped into the new error as well. This allows us to inspect the returned error and perform checks on it. The proposal gives us two functions to help with that: errors.Is and errors.As.

func As(err error, target interface{}) bool

As finds the first error in err’s chain that matches the type to which target points, and if so, sets the target to its value and returns true. An error matches a type if it is assignable to the target type, or if it has a method As(interface{}) bool such that As(target) returns true.

So in our case, to check whether err is of type pq.Error:

func main() {
	// getting the db handle is omitted for brevity
	err := insert(db)
	if err != nil {
		log.Printf("%+v\n", err)
	}
	pqe := &pq.Error{}
	if errors.As(err, &pqe) {
		log.Println("Yep, a pq.Error")
	}
}
2019/03/20 14:28:33 main.doDB
        /home/agniva/play/go/src/main.go:72
  - pq: syntax error at or near "INTOtablename"
2019/03/20 14:28:33 Yep, a pq.Error

func Is(err, target error) bool

Is reports whether any error in err’s chain matches target. An error is considered to match a target if it is equal to that target or if it implements a method Is(error) bool such that Is(target) returns true.

Continuing with our previous example:

func main() {
	// getting the db handle is omitted for brevity
	err := insert(db)
	if err != nil {
		log.Printf("%+v\n", err)
	}
	pqe := &pq.Error{}
	if errors.As(err, &pqe) {
		log.Println("Yep, a pq.Error")
	}
	if errors.Is(err, sql.ErrNoRows) {
		log.Println("Yep, a sql.ErrNoRows")
	}
}
2019/03/20 14:29:03 main.doDB
        /home/agniva/play/go/src/main.go:72
  - pq: syntax error at or near "INTOtablename"
2019/03/20 14:29:03 Yep, a pq.Error

ErrNoRows did not match, which is what we expect.

Custom error types can also be wrapped and checked in a similar manner. But to be able to unwrap the error, the type needs to satisfy the Wrapper interface and have an Unwrap method which returns the inner error. Let’s say we want to return ErrNoUser if a sql.ErrNoRows is returned. We can do:

type ErrNoUser struct {
	err error
}

func (e ErrNoUser) Error() string {
	return e.err.Error()
}

// Unwrap satisfies the Wrapper interface.
func (e ErrNoUser) Unwrap() error {
	return e.err
}

func main() {
	// getting the db handle is omitted for brevity
	err := getUser(db)
	if err != nil {
		log.Printf("%+v\n", err)
	}
	ff := ErrNoUser{}
	if errors.As(err, &ff) {
		log.Println("Yep, ErrNoUser")
	}
}

func getUser(db *sql.DB) error {
	var id int
	err := db.QueryRow(`SELECT id from tablename WHERE name=$1`, "notexist").Scan(&id)
	if err == sql.ErrNoRows {
		return fmt.Errorf(": %w", ErrNoUser{err: err})
	}
	return err
}
2019/03/21 10:56:16 main.getUser
        /home/agniva/play/go/src/main.go:100
  - sql: no rows in result set
2019/03/21 10:56:16 Yep, ErrNoUser

This is mostly my take on how to integrate the new changes into a codebase. But it is in no way an exhaustive tutorial on it. For a deeper look, please feel free to read the proposal. There is also an FAQ which touches on some useful topics.

TLDR

There is a new proposal which makes some changes to the errors and fmt packages. The highlights of which are:

  • All errors returned by errors.New and fmt.Errorf now capture stack information.
  • The stack can be printed by using %+v which is the “detail mode”.
  • For fmt.Errorf, if the last argument is an error and the format string ends with : %s, : %v or : %w, the returned error will have the FormatError method. In case of %w, the error will also be wrapped and have the Unwrap method.
  • There are two new convenience functions, errors.Is and errors.As, which allow for error inspection.

As always, please feel free to point out any errors or suggestions in the comments. Thanks for reading!

How to write a Vet analyzer pass

The Go toolchain has the vet command, which can be used to perform static checks on a codebase. But a significant problem with vet was that it was not extensible: vet was structured as a monolithic executable with a fixed suite of checkers. To overcome this, the ecosystem started developing its own tools like staticcheck and go-critic. The problem with this is that every tool has its own way to load and parse the source code, so a checker written for one tool would require extensive effort to run on a different driver.

During the 1.12 release cycle, a new API for static code analysis was developed: the golang.org/x/tools/go/analysis package. This creates a standard API for writing Go static analyzers, which allows them to be easily shared with the rest of the ecosystem in a plug-and-play model.

In this post, we will see how to go about writing an analyzer using this new API.

Background

SQL queries are always evaluated at runtime. As a result, if you make a syntax error in a SQL query, there is no way to catch it until you run the code or write a test for it. One peculiar pattern in particular kept tripping me up.

Let’s say I have a SQL query like:

db.Exec("insert into table (c1, c2, c3, c4) values ($1, $2, $3, $4)", p1, p2, p3, p4)

It’s the middle of the night and I need to add a new column. I quickly change the query to:

db.Exec("insert into table (c1, c2, c3, c4, c5) values ($1, $2, $3, $4)", p1, p2, p3, p4, p5).

It seems like things are fine, but I have just missed a $5. This bugged me so much that I wanted to write a vet analyzer for this to detect patterns like these and flag them.

There are other semantic checks we can apply, like matching the number of positional args with the number of params passed, and so on. But in this post, we will just focus on the most basic check: verifying whether a SQL query is syntactically correct.

Layout of an analyzer

Analyzers usually expose a global variable Analyzer of type *analysis.Analyzer. It is this variable which is imported by driver packages.

Let us see what it looks like:

var Analyzer = &analysis.Analyzer{
	Name:             "sqlargs",                                 // name of the analyzer
	Doc:              "check sql query strings for correctness", // documentation
	Run:              run,                                       // perform your analysis here
	Requires:         []*analysis.Analyzer{inspect.Analyzer},    // a set of analyzers which must run before the current one.
	RunDespiteErrors: true,
}

Most of the fields are self-explanatory. The actual analysis is performed by run: a function which takes an *analysis.Pass as an argument. The pass variable provides information to the run function to perform its tasks, and optionally lets it pass on information to other analyzers.

It looks like:

func run(pass *analysis.Pass) (interface{}, error) {
}

Now, to run this analyzer, we will use the singlechecker package which can be used to run a single analyzer.

package main

import (
	"github.com/agnivade/sqlargs"
	"golang.org/x/tools/go/analysis/singlechecker"
)

func main() { singlechecker.Main(sqlargs.Analyzer) }

Upon successfully compiling this, you can execute the binary as a standalone tool on your codebase: sqlargs ./....

This is the standard layout of all analyzers. Let us have a look into the internals of the run function, which is where the main code analysis is performed.

Look for SQL queries

Our primary aim is to look for expressions like db.Exec("<query>") in the code base and analyze them. This requires knowledge of Go ASTs (Abstract Syntax Tree) to slice and dice the source code and extract the stuff that we need.

To help us with scavenging the codebase and filtering the AST expressions that we need, we have some tools at our disposal, most notably the golang.org/x/tools/go/ast/inspector package. Using this, we just specify the node types in the source code that we are interested in, and it does the rest. Since this is a very common task for all analyzers, there is an inspect pass which returns an inspector for a given pass.

Let us see what that looks like:

import (
	"go/ast"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
	"golang.org/x/tools/go/ast/inspector"
)

func run(pass *analysis.Pass) (interface{}, error) {
	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	// We filter only function calls.
	nodeFilter := []ast.Node{
		(*ast.CallExpr)(nil),
	}

	inspect.Preorder(nodeFilter, func(n ast.Node) {
		call := n.(*ast.CallExpr)
		_ = call // work with the call expression that we have
	})
	return nil, nil
}

All expressions of the form of db.Exec("<query>") are called CallExprs. So we specify that in our nodeFilter. After that, the Preorder function will give us only CallExprs found in the codebase.

A CallExpr has two parts: Fun and Args. A Fun can either be an Ident (for example Fun()) or a SelectorExpr (for example foo.Fun()). Since we are looking for patterns like db.Exec, we need to filter only SelectorExprs.

inspect.Preorder(nodeFilter, func(n ast.Node) {
	call := n.(*ast.CallExpr)
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok {
		return
	}

})

Alright, so far so good. This means we have filtered all expressions of the form type.Method() from the source code. Now we need to verify two things:

  1. The function name is Exec; because that is what we are interested in.
  2. The type of the selector is sql.DB. (To keep things simple, we will ignore the case when sql.DB is embedded in another struct).

Let us peek into the SelectorExpr to get these. A SelectorExpr again has two parts: X and Sel. If we take db.Exec() as an example, db is X and Exec is Sel. Matching the function name is easy. But to get the type info, we need the help of the analysis.Pass passed to the run function.

Pass contains a TypesInfo field which contains type information about the package. We need to use that to get the type of X and verify that the object comes from the database/sql package and is of type *sql.DB.

inspect.Preorder(nodeFilter, func(n ast.Node) {
	call := n.(*ast.CallExpr)
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok {
		return
	}

	// Get the type of X
	typ, ok := pass.TypesInfo.Types[sel.X]
	if !ok {
		return
	}

	t := typ.Type
	// If it is a pointer, get the element.
	if ptr, ok := t.(*types.Pointer); ok {
		t = ptr.Elem()
	}
	nTyp, ok := t.(*types.Named)
	if !ok {
		return
	}
})

Now, from nTyp we can get the type info of X and directly match the function name from Sel.

// Get the function name
sel.Sel.Name // == "Exec"

// Get the object name
nTyp.Obj().Name() // == "DB"

// Check the import of the object
nTyp.Obj().Pkg().Path() // == "database/sql"

Extract the query string

Alright! We have successfully filtered out only expressions of type (*sql.DB).Exec. The only thing remaining is to extract the query string from the CallExpr and check it for syntax errors.

So far, we have been dealing with the Fun field of a CallExpr. To get the query string, we need to access Args. A db.Exec call will have the query string as its first param and the arguments follow after. We will get the first element of the Args slice and then use TypesInfo.Types again to get the value of the argument.

// Code continues from before.

arg0 := call.Args[0]
typ, ok := pass.TypesInfo.Types[arg0]
if !ok || typ.Value == nil {
	return
}

_ = constant.StringVal(typ.Value) // Gives us the query string ! (constant is from "go/constant")

Note that this doesn’t work if the query string is a variable. A lot of codebases have a query template string and generate the final query string dynamically. So, for example, the following will not be caught by our analyzer:

q := `SELECT %s FROM certificates WHERE date=$1;`
query := fmt.Sprintf(q, table)
db.Exec(query, date)

All that is left is for us to check the query string for syntax errors. We will use the github.com/lfittl/pg_query_go package for that. And if we get an error, pass has a Reportf helper method to print out diagnostics found during a vet pass. So:

query := constant.StringVal(typ.Value)
_, err := pg_query.Parse(query)
if err != nil {
	pass.Reportf(call.Lparen, "Invalid query: %v", err)
	return
}

The final result looks like this:

func run(pass *analysis.Pass) (interface{}, error) {
	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	// We filter only function calls.
	nodeFilter := []ast.Node{
		(*ast.CallExpr)(nil),
	}

	inspect.Preorder(nodeFilter, func(n ast.Node) {
		call := n.(*ast.CallExpr)
		sel, ok := call.Fun.(*ast.SelectorExpr)
		if !ok {
			return
		}

		// Get the type of X
		typ, ok := pass.TypesInfo.Types[sel.X]
		if !ok {
			return
		}

		t := typ.Type
		// If it is a pointer, get the element.
		if ptr, ok := t.(*types.Pointer); ok {
			t = ptr.Elem()
		}
		nTyp, ok := t.(*types.Named)
		if !ok {
			return
		}

		if sel.Sel.Name != "Exec" ||
			nTyp.Obj().Name() != "DB" ||
			nTyp.Obj().Pkg().Path() != "database/sql" {
			return
		}

		arg0 := call.Args[0]
		typ, ok = pass.TypesInfo.Types[arg0]
		if !ok || typ.Value == nil {
			return
		}

		query := constant.StringVal(typ.Value)
		_, err := pg_query.Parse(query)
		if err != nil {
			pass.Reportf(call.Lparen, "Invalid query: %v", err)
			return
		}
	})
	return nil, nil
}

Tests

The golang.org/x/tools/go/analysis/analysistest package provides several helpers to make testing of vet passes a breeze. We just need to have our sample code that we want to test in a package. That package should reside inside the testdata folder which acts as the GOPATH for the test.

Let’s say we have a file basic.go which contains db.Exec function calls that we want to test. So the folder structure needed is:

testdata
    └── src
        └── basic
            └── basic.go

To verify expected diagnostics, we just need to add comments of the form // want ".." beside the line which is expected to throw the error. So for example, this is what the file basic.go might look like-

func runDB() {
	var db *sql.DB
	defer db.Close()

	db.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") // no error
	db.Exec(`INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"`
}

And finally to run the test, we import the analysistest package and pass our analyzer, pointing to the package that we want to test.

import (
	"testing"

	"github.com/agnivade/sqlargs"
	"golang.org/x/tools/go/analysis/analysistest"
)

func TestBasic(t *testing.T) {
	testdata := analysistest.TestData()
	analysistest.Run(t, testdata, sqlargs.Analyzer, "basic") // loads testdata/src/basic
}

That’s it!

To quickly recap:

  1. We saw the basic layout of all analyzers.
  2. We used the inspect pass to filter the AST nodes that we want.
  3. Once we got our node, we used the pass.TypesInfo.Types map to give us type information about an object.
  4. We used that to verify that the received object comes from the database/sql package and is of type *sql.DB.
  5. Then we extracted the first argument from the CallExpr and checked whether the string is a valid SQL query or not.

This was a short demo of how to go about writing a vet analyzer. Note that SQL strings can also appear with other libraries like sqlx or gorm. Matching only objects of type *sql.DB is not enough; one needs to maintain a list of type and method names to match. But I have kept things simple for the sake of the article. The full source code is available here. Please feel free to download and run sqlargs on your codebase. If you find a mistake in the article, please do point it out in the comments!