I'm working on setting up an app backend that's deployed automatically using infrastructure-as-code and a CI/CD pipeline. As a result, I need the app server to gather DB credentials automatically. This is my first time working with the AWS Go SDK, and I want to get the code as close to production-ready as possible.
Credentials for the AWS SDK are set when instances are spun up, so this code only needs two base parameters: the SSM parameter path root (e.g. /app/db/) and the AWS region to connect to. Both are pulled from a .env file.
The SDK is then used to pull encrypted strings from the AWS SSM Parameter Store, which are used to build a connection string and connect to the database. I also included a small server that pings the database and tells the user whether it's connected.
dbConnector.go:
package main
import (
"context"
"time"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/jmoiron/sqlx"
)
// timeout is the number of seconds ConnectToDB allows for resolving the
// connection settings and establishing the connection.
const timeout = 5

// Database wraps an sqlx database handle. Embedding *sqlx.DB promotes all of
// its methods (PingContext, Close, query helpers, ...) onto Database.
type Database struct {
	*sqlx.DB
}
// ConnectToDB opens a PostgreSQL connection through the pgx driver, using the
// connection string produced by configString. The SSM lookups performed by
// configString and the connection attempt share a single deadline of timeout
// seconds derived from ctx.
func ConnectToDB(ctx context.Context) (*Database, error) {
	deadlineCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
	defer cancel()

	dsn := configString(deadlineCtx)
	sqlDB, err := sqlx.ConnectContext(deadlineCtx, "pgx", dsn)
	if err != nil {
		return nil, err
	}
	return &Database{DB: sqlDB}, nil
}
// Connected reports whether the database answers a ping.
//
// A nil *Database, or one without an underlying handle (for example when the
// startup connection failed), is reported as not connected instead of
// panicking on the nil embedded *sqlx.DB. The ping is bounded by timeout
// seconds so a hung server cannot stall the caller indefinitely; ctx can still
// cancel it sooner.
func (db *Database) Connected(ctx context.Context) bool {
	if db == nil || db.DB == nil {
		return false
	}
	ctx, cancel := context.WithTimeout(ctx, timeout*time.Second)
	defer cancel()
	return db.PingContext(ctx) == nil
}
dbConfig.go:
package main
import (
"context"
"fmt"
"log"
"strconv"
"github.com/spf13/viper"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
)
// Location of the base config file read by viper in loadBaseConfig.
const (
	baseConfig = "base_config" // file name, without extension
	basePath   = "."           // directory searched, relative to the working directory
)

// Keys looked up in the base config file.
const (
	baseRegion = "AWS_REGION" // region used for the SSM client
	baseRoot   = "AWS_ROOT"   // prefix prepended to every SSM parameter name
)

// Connection settings. ssmDB is used directly as the database name; the
// others are SSM parameter names appended to the AWS_ROOT prefix.
const (
	ssmDB       = "postgres"
	ssmHost     = "host"
	ssmPort     = "port"
	ssmUser     = "user"
	ssmPassword = "password"
)
// awsSSM embeds the SDK's SSM client so package-local helpers (such as param)
// can be defined on it while all SDK methods remain directly callable.
type awsSSM struct{ *ssm.SSM }
// configString builds a keyword/value connection string for the database named
// ssmDB. Host, port, user and password are read from SSM parameters named
// AWS_ROOT+<key> in region AWS_REGION (both taken from the base config).
//
// Every string value is single-quoted and escaped via quoteDSNValue. As a
// result, a password (or any other value) containing spaces, quotes or
// backslashes cannot break or inject into the connection string. An empty
// value, e.g. from a failed lookup, also stays empty instead of swallowing
// the following keyword.
//
// Lookup failures are only logged (by param and atoui) and appear here as
// empty values or port 0; they are not reported to the caller.
func configString(ctx context.Context) string {
	loadBaseConfig()
	// NOTE(review): session.New is deprecated in favour of session.NewSession,
	// which reports configuration errors up front; switching needs an error
	// return this function does not have yet.
	sess := session.New()
	svc := awsSSM{ssm.New(sess, &aws.Config{
		Region: aws.String(viper.GetString(baseRegion)),
	})}
	return fmt.Sprintf("database=%s host=%s port=%d user=%s password=%s",
		quoteDSNValue(ssmDB),
		quoteDSNValue(svc.param(ctx, ssmHost)),
		atoui(svc.param(ctx, ssmPort)),
		quoteDSNValue(svc.param(ctx, ssmUser)),
		quoteDSNValue(svc.param(ctx, ssmPassword)),
	)
}

// quoteDSNValue wraps s in single quotes and backslash-escapes any embedded
// single quote or backslash. This follows the quoting rules of the
// libpq-style keyword/value connection-string syntax parsed by pgx.
func quoteDSNValue(s string) string {
	quoted := make([]byte, 0, len(s)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(s); i++ {
		if c := s[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, s[i])
	}
	return string(append(quoted, '\''))
}
// loadBaseConfig points viper at the base config file (base_config.* in the
// current working directory) and reads it. This file supplies AWS_ROOT and
// AWS_REGION; all other configuration is provided when the instance is
// spun up.
//
// A missing or unreadable file is only logged, so the keys above will
// typically resolve to empty strings afterwards.
func loadBaseConfig() {
	viper.AddConfigPath(basePath)
	viper.SetConfigName(baseConfig)
	if err := viper.ReadInConfig(); err != nil {
		log.Println(err)
	}
}
// param fetches the SSM parameter named AWS_ROOT+p with decryption enabled and
// returns its value. The request honours ctx for cancellation and deadlines.
//
// AWS_ROOT is concatenated verbatim, so it is expected to end with "/"
// (e.g. "/app/db/").
//
// On failure, or if the response carries no parameter, the problem is logged
// together with the full parameter name and "" is returned. The name is
// included because SSM errors such as ParameterNotFound do not necessarily
// say which parameter was missing.
func (svc *awsSSM) param(ctx context.Context, p string) string {
	name := viper.GetString(baseRoot) + p
	output, err := svc.GetParameterWithContext(ctx, &ssm.GetParameterInput{
		Name:           aws.String(name),
		WithDecryption: aws.Bool(true),
	})
	if err != nil {
		log.Printf("ssm: getting parameter %q: %v", name, err)
		return ""
	}
	if output.Parameter == nil {
		log.Printf("ssm: parameter %q: response contained no parameter", name)
		return ""
	}
	return aws.StringValue(output.Parameter.Value)
}
// atoui parses s as a base-10 uint16 (used here for the database port).
//
// Input that is not an unsigned integer, or that does not fit in 16 bits
// (e.g. "70000"), is logged and yields 0. Previously it was parsed as 64-bit
// and silently truncated to a different port.
func atoui(s string) uint16 {
	n, err := strconv.ParseUint(s, 10, 16)
	if err != nil {
		log.Println(err)
		return 0
	}
	return uint16(n)
}
main.go:
package main
import (
	"context"
	"fmt"
	"log"
	"net/http"
	"os"
	"time"

	"github.com/julienschmidt/httprouter"
)
// Server bundles the dependencies shared by the HTTP handlers.
type Server struct {
	db  *Database          // nil when the startup connection attempt failed
	mux *httprouter.Router // request router passed to the HTTP server
	log *log.Logger        // destination for operational messages
}
// main wires up the logger, router and database connection, then serves HTTP
// on :8050 until the listener fails.
func main() {
	s := Server{
		log: log.New(os.Stdout, log.Prefix(), log.Flags()),
		mux: httprouter.New(),
	}

	db, err := ConnectToDB(context.Background())
	if err != nil {
		// Keep serving without a database so the index route can report the
		// outage; s.db stays nil in that case and handlers must tolerate it.
		s.log.Println(err)
	}
	s.db = db
	s.mux.GET("/", s.index())

	// http.ListenAndServe uses a server with no timeouts, letting slow or idle
	// clients hold connections (and goroutines) open indefinitely.
	srv := &http.Server{
		Addr:              ":8050",
		Handler:           s.mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}
// index returns the handler registered for GET /. It pings the database and
// replies "Connected!" on success. When the ping fails, or no database
// connection exists at all, it replies 503 Service Unavailable and logs the
// failure.
func (s *Server) index() httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
		// s.db is nil when the startup connection failed; check it first so the
		// handler reports the outage instead of panicking.
		if s.db == nil || !s.db.Connected(r.Context()) {
			// 503 marks a temporary dependency outage; 404 would claim the route
			// itself does not exist.
			http.Error(w, "Error connecting to database", http.StatusServiceUnavailable)
			s.log.Println("Error connecting to database")
			return
		}
		fmt.Fprint(w, "Connected!")
	}
}