How to pass a parameter to a MiddlewareFunc? - rest

I'm working on a small demo that tries to explain how a basic HTTP handler works, and I found the following example:
package main
// router builds the application's mux router. Routes under the "/auth"
// prefix are registered on a subrouter that has auth.ValidateToken installed
// as middleware.
func router() *mux.Router {
	root := mux.NewRouter()
	// The subrouter is named "secure" rather than "auth" so the local variable
	// does not shadow the auth package, which must stay reachable for
	// auth.ValidateToken on the next line.
	secure := root.PathPrefix("/auth").Subrouter()
	secure.Use(auth.ValidateToken)
	secure.HandleFunc("/api", middleware.ApiHandler).Methods("GET")
	return root
}
// main serves the router on port 8080 and aborts loudly if the server stops.
func main() {
	r := router()
	// ListenAndServe only returns when it fails (for example when the port is
	// already taken); report that instead of exiting silently.
	if err := http.ListenAndServe(":8080", r); err != nil {
		panic(err)
	}
}
package auth
// ValidateToken is a middleware that lets a request reach next only when its
// "secret-access-token" header, with surrounding whitespace trimmed, equals
// "SecretValue". A missing or wrong token is answered with 403 Forbidden and
// a JSON-encoded message, and next is not called.
func ValidateToken(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token := strings.TrimSpace(r.Header.Get("secret-access-token"))

		// Nothing is written to w before the token has been checked and
		// nothing is written on success: an early body write would commit a
		// 200 status, take control of the response away from next, and echo
		// the secret back to the client.
		var problem string
		switch token {
		case "":
			problem = "Missing auth token"
		case "SecretValue":
			// valid token, fall through to next below
		default:
			problem = "Auth token is invalid"
		}

		if problem != "" {
			w.WriteHeader(http.StatusForbidden)
			// Best effort: the status line is already sent, so a failed body
			// write cannot be reported to the client any more.
			_ = json.NewEncoder(w).Encode(problem)
			return
		}

		next.ServeHTTP(w, r)
	})
}
package middleware
// ApiHandler answers every request with status 200, a JSON content type and
// the JSON string "SUCCESS!".
func ApiHandler(w http.ResponseWriter, r *http.Request) {
	// Headers have to be set before WriteHeader: anything added to the header
	// map after the status line is sent is not part of the response.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// Best effort: the status is already on the wire, so an encoding failure
	// cannot be turned into an error response here.
	_ = json.NewEncoder(w).Encode("SUCCESS!")
}
Everything is understandable, but two questions came to mind:
Why doesn't this line require passing a parameter: secure.Use(auth.ValidateToken)?
If I wanted to send an extra parameter to this auth.ValidateToken func, (example, a string, for which this header should be compared instead of "SecretValue"), how can this be done?
Thanks in advance, I'm kind of new at using golang and wanted to learn more.

Use takes a middleware function as its argument, and auth.ValidateToken is the function you pass: it is handed over as a value, not called, which is why no parameter appears on that line. If you want to send an argument to auth.ValidateToken, you can write a function that takes that argument and returns a middleware function, like this:
// GetValidateTokenFunc takes the configuration value up front and returns a
// middleware (a func(http.Handler) http.Handler) that closes over it, so the
// result can be handed to a router's Use method. In this snippet the captured
// value is the name of the header to read; the remainder of the handler body
// is elided.
func GetValidateTokenFunc(headerName string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var header = r.Header.Get(headerName)
			...
		}) // closes the http.HandlerFunc( conversion opened above
	}
}
Then you can do:
auth.Use(auth.GetValidateTokenFunc("headerName"))

Related

How to log only errors like 404 with chi in go?

I'm using chi with our Go webservices.
How to configure the logger (middleware) so it only logs requests that ended up with errors (like 404) but it doesn't log successful requests (with status code 200)?
Here's our current implementation (with no logging at all)
r := chi.NewRouter()
// The request logger middleware is installed only when DEBUG_LOGS is set.
if DEBUG_LOGS {
r.Use(middleware.Logger)
} else {
// No middleware is registered in this branch, so nothing is logged.
}
The easiest way is to implement the logging function by yourself using the example from the chi package (for simplicity, I removed the colors).
package main
import (
"bytes"
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
const DEBUG_LOGS = true
// main starts an HTTP server on :8000 with two demo routes. When DEBUG_LOGS
// is set, requests pass through chi's RequestLogger backed by StructuredLogger.
func main() {
	mux := chi.NewRouter()

	if DEBUG_LOGS {
		// create default logger/zerolog/logrus
		stdLogger := log.New(os.Stdout, "", log.LstdFlags)
		mux.Use(middleware.RequestLogger(&StructuredLogger{stdLogger}))
	}

	mux.Get("/tea", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	})
	mux.Get("/ok", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	srv := &http.Server{Addr: ":8000", Handler: mux}
	if err := srv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}
// below is the implementation of the custom logger.

// StructuredLogger creates one LogEntry per request and carries the Logger
// that each entry finally prints to.
type StructuredLogger struct {
Logger middleware.LoggerInterface
}

// LogEntry accumulates the log line for a single request in buf.
type LogEntry struct {
// embedded so that the entry can reach Logger directly
*StructuredLogger
// the request this entry was created for
request *http.Request
// the log line built so far
buf *bytes.Buffer
// set to false by NewLogEntry; not read anywhere in this snippet
useColor bool
}
// NewLogEntry starts the log line for r. The line gets an optional
// "[request-id] " prefix, then the quoted request line (method, full URL,
// protocol) and the remote address; the status part is appended later by Write.
func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry {
	line := &bytes.Buffer{}

	if id := middleware.GetReqID(r.Context()); id != "" {
		fmt.Fprintf(line, "[%s] ", id)
	}

	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	fmt.Fprintf(line, "\"%s %s://%s%s %s\" ", r.Method, scheme, r.Host, r.RequestURI, r.Proto)
	line.WriteString("from " + r.RemoteAddr + " - ")

	return &LogEntry{
		StructuredLogger: l,
		request:          r,
		buf:              line,
		useColor:         false,
	}
}
// Write finishes the log line with the status code, the response size and
// the elapsed time, and prints it through the embedded Logger. Responses with
// a status below 300 are not logged at all. header and extra are not used.
func (l *LogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
	// Only requests that did not end in 1xx/2xx are worth a log line.
	if status < 300 {
		return
	}
	// The duration needs no branching on its magnitude: every branch of the
	// former if/else chain printed it in exactly the same way.
	fmt.Fprintf(l.buf, "%03d %dB in %s", status, bytes, elapsed)
	l.Logger.Print(l.buf.String())
}
// Panic hands the panic value v to middleware.PrintPrettyStack.
// The stack argument is not used.
func (l *LogEntry) Panic(v interface{}, stack []byte) {
middleware.PrintPrettyStack(v)
}

Temporary lock the resource until goroutine is finished

I have a method, which sends HTTP status-code 202 Accepted as indicator of successful request, and runs another process in it (goroutine). Something like that:
// Handler that acknowledges the request with 202 Accepted and then starts
// the background work in its own goroutine.
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
go func() {
// in this example the goroutine does nothing but sleep for two seconds
time.Sleep(2 * time.Second)
}()
}
I want to temporarily lock the resource to prevent multiple execution of goroutine process.
// Attempt to guard the background work with a channel.
// NOTE(review): c is created anew inside every call, is unbuffered, and
// nothing has sent on it by the time the select runs, so the receive case is
// never ready and the default branch (423 Locked) is taken on every request.
return func(w http.ResponseWriter, r *http.Request) {
c := make(chan bool)
select {
case _, unlocked := <-c:
if unlocked {
break
}
default:
w.WriteHeader(http.StatusLocked)
return
}
w.WriteHeader(http.StatusAccepted)
// the goroutine signals completion on c, but only this call holds a reference to c
go func(c chan bool) {
time.Sleep(2 * time.Second)
c <- true
}(c)
}
I always get the 423 Locked status code. I think I don't understand channels yet. Maybe I should try to use mutexes?
Not the best solution, but it works:
// NewEvents returns a handler that starts the background job and answers
// 202 Accepted, unless a job started by an earlier request is still running,
// in which case it answers 423 Locked and starts nothing.
func (h *HookHandler) NewEvents() http.HandlerFunc {
	// slot is a one-element semaphore shared by all requests to the returned
	// handler: holding its single token means "a job is running". Claiming it
	// with a non-blocking send makes the check and the acquisition one atomic
	// step, so two concurrent requests can never both be accepted.
	slot := make(chan struct{}, 1)

	return func(w http.ResponseWriter, r *http.Request) {
		select {
		case slot <- struct{}{}:
			// token acquired, this request owns the job
		default:
			w.WriteHeader(http.StatusLocked)
			return
		}

		w.WriteHeader(http.StatusAccepted)
		go func() {
			// Give the token back when the job ends, even if it panics.
			defer func() { <-slot }()
			time.Sleep(2 * time.Second)
		}()
	}
}

How to write a generic handler in Go?

I need to create a HTTP Handler for a REST API.
This REST API have many different objects that are stored in a database (MongoDB in my case).
Currently, I need to write one handler per action per object.
I would like to find a way like it's possible with Generics to write a generic handler that could handle a specific action but for any kind of object (As basically it's just CRUD in most of the case)
How can I do this ?
Here are examples of handlers I would like to transform into a generic one:
// IngredientIndex writes every document of the "ingredients" collection as a
// JSON array. The *mgo.Database is taken from the request context under the
// "Database" key; a failing query panics.
func IngredientIndex(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=UTF-8")

	database := r.Context().Value("Database").(*mgo.Database)
	query := database.C("ingredients").Find(bson.M{})

	// start from an empty (non-nil) slice, as the original does
	all := []data.Ingredient{}
	if err := query.All(&all); err != nil {
		panic(err)
	}

	json.NewEncoder(w).Encode(all)
}
// IngredientGet writes the ingredient whose _id is given by the "id" route
// variable as JSON. It answers 400 when the id is not a valid ObjectId hex
// string and 404 when no such document can be loaded. Logger and database
// are taken from the request context ("Logger" and "Database" keys).
func IngredientGet(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json; charset=UTF-8")
	vars := mux.Vars(r)
	logger := r.Context().Value("Logger").(zap.Logger)
	db := r.Context().Value("Database").(*mgo.Database)

	// bson.ObjectIdHex panics on input that is not a valid hex ObjectId, and
	// the id comes straight from the URL, so it has to be checked first.
	id := vars["id"]
	if !bson.IsObjectIdHex(id) {
		w.WriteHeader(http.StatusBadRequest)
		logger.Info("Invalid entity id", zap.String("id", id))
		return
	}

	ingredient := data.Ingredient{}
	err := db.C("ingredients").Find(bson.M{"_id": bson.ObjectIdHex(id)}).One(&ingredient)
	if err != nil {
		w.WriteHeader(http.StatusNotFound)
		logger.Info("Fail to find entity", zap.Error(err))
		return
	}
	json.NewEncoder(w).Encode(ingredient)
}
Basically I would need a handler like this (this is my attempt, which doesn't work):
// ObjectIndex is meant to build a handler that lists every document of
// collection as JSON, with container determining the type to decode into.
// NOTE(review): &objects below is a pointer to an interface{} variable, not a
// pointer to the caller's concrete slice type — presumably why decoding does
// not work; confirm against what mgo's Query.All accepts.
// NOTE(review): the closing brace of the outer function is not part of this snippet.
func ObjectIndex(collection string, container interface{}) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
db := r.Context().Value("Database").(*mgo.Database)
objects := container
err := db.C(collection).Find(bson.M{}).All(&objects)
if err != nil {
panic(err)
}
json.NewEncoder(w).Encode(objects)
}
How can I do that ?
Use the reflect package to create a new container on each invocation:
func ObjectIndex(collection string, container interface{}) func(http.ResponseWriter, *http.Request) {
t := reflect.TypeOf(container)
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
db := r.Context().Value("Database").(*mgo.Database)
objects := reflect.New(t).Interface()
err := db.C(collection).Find(bson.M{}).All(objects)
if err != nil {
panic(err)
}
json.NewEncoder(w).Encode(objects)
}
Call it like this:
h := ObjectIndex("ingredients", data.Ingredient{})
assuming that data.Ingredient is a slice type.

How to access attribute of interface

It was my intention to use the HTTP status code both in the header and in the body of the two response structs, but without passing the status code twice (once as a function parameter and again in the struct), to avoid redundancy.
The parameter response of JSON() is an interface to allow both structs to be accepted. The compiler throws the following exception:
response.Status undefined (type interface {} has no field or method Status)
because the response field must not have a status attribute. Is there an alternative way to avoid setting the status code twice?
// Response is the body of a successful reply: the status code plus an
// arbitrary payload, serialized under the JSON keys "status" and "data".
type Response struct {
Status int `json:"status"`
Data interface{} `json:"data"`
}

// ErrorResponse is the body of a failed reply: the status code plus a list
// of messages, serialized under the JSON keys "status" and "errors".
type ErrorResponse struct {
Status int `json:"status"`
Errors []string `json:"errors"`
}
// JSON marshals response as indented JSON and is meant to send the response's
// own status code as the HTTP status. The rest of the body is elided.
// NOTE(review): response is declared as interface{}, which exposes no fields,
// so response.Status below does not compile.
// NOTE(review): the error returned by json.MarshalIndent is discarded.
func JSON(rw http.ResponseWriter, response interface{}) {
payload, _ := json.MarshalIndent(response, "", " ")
rw.WriteHeader(response.Status)
...
}
The type response in rw.WriteHeader(response.Status) is interface{}. In Go, you need to explicitly assert the type of the underlying struct and then access the field:
// JSON marshals response as indented JSON and sends the Status field of the
// concrete response type as the HTTP status. The rest of the body is elided.
// NOTE(review): the error returned by json.MarshalIndent is discarded.
func JSON(rw http.ResponseWriter, response interface{}) {
payload, _ := json.MarshalIndent(response, "", " ")
// Only the value types ErrorResponse and Response are matched; any other
// type (including pointers to those structs) matches no case, so WriteHeader
// is not called here for it.
switch r := response.(type) {
case ErrorResponse:
rw.WriteHeader(r.Status)
case Response:
rw.WriteHeader(r.Status)
}
...
}
A better and the preferred way to do this however is to define a common interface for your responses, that has a method for getting the status of the response:
// Statuser is implemented by any response that can report its own status code.
type Statuser interface {
Status() int
}
// You need to rename the fields to avoid name collision.
// (A struct cannot have both a field and a method named Status; the methods
// below therefore read a field called ResStatus.)
func (r Response) Status() int { return r.ResStatus }
func (r ErrorResponse) Status() int { return r.ResStatus }

// JSON marshals response as indented JSON and sends response.Status() as the
// HTTP status. The rest of the body is elided.
// NOTE(review): the error returned by json.MarshalIndent is discarded.
func JSON(rw http.ResponseWriter, response Statuser) {
payload, _ := json.MarshalIndent(response, "", " ")
rw.WriteHeader(response.Status())
...
}
And it's better to rename Response to DataResponse and the interface (Statuser above) to Response, IMO.
Interfaces don't have attributes, so you need to extract the struct from the interface. To do this you use a type assertion
if response, ok := response.(ErrorResponse); ok {
rw.WriteHeader(response.Status)
...

Accessing the underlying socket of a net/http response

I'm new to Go and evaluating it for a project.
I'm trying to write a custom handler to serve files with net/http.
I can't use the default http.FileServer() handler because I need to have access to the underlying socket (the internal net.Conn) so I can perform some informational platform specific "syscall" calls on it (mainly TCP_INFO).
More precisely: I need to access the underlying socket of the http.ResponseWriter in the handler function:
// myHandler is the request handler that needs access to the net.Conn behind w.
// Its body is elided.
func myHandler(w http.ResponseWriter, r *http.Request) {
...
// I need the net.Conn of w
...
}
used in
http.HandleFunc("/", myHandler)
Is there a way to do this? I looked at how websocket.Upgrade does this, but it uses Hijack(), which is 'too much' because then I have to implement 'speaking HTTP' myself over the raw TCP socket I get. I just want a reference to the socket, not to take it over completely.
After Issue #30694 is completed, it looks like Go 1.13 will probably support storing the net.Conn in the Request Context, which makes this fairly clean and simple:
package main
import (
"net/http"
"context"
"net"
"log"
)
// contextKey is a private key type, so values stored under it cannot collide
// with context keys defined by other packages.
type contextKey struct {
	key string
}

// ConnContextKey is the context key under which the connection serving a
// request is stored.
var ConnContextKey = &contextKey{"http-conn"}

// SaveConnInContext returns a copy of ctx that carries c. Its signature
// matches http.Server.ConnContext.
func SaveConnInContext(ctx context.Context, c net.Conn) context.Context {
	return context.WithValue(ctx, ConnContextKey, c)
}

// GetConn returns the connection stored in r's context by SaveConnInContext,
// or nil when there is none (for example when the request did not come
// through a server configured with ConnContext, as in tests).
func GetConn(r *http.Request) net.Conn {
	// The two-value assertion yields a nil net.Conn instead of panicking when
	// the key is absent.
	c, _ := r.Context().Value(ConnContextKey).(net.Conn)
	return c
}
// main serves myHandler on :8080 with a server whose ConnContext hook stores
// every connection in the context of the requests it carries.
func main() {
	http.HandleFunc("/", myHandler)
	server := &http.Server{
		Addr:        ":8080",
		ConnContext: SaveConnInContext,
	}
	// ListenAndServe only returns on failure; report it instead of exiting
	// silently.
	log.Fatal(server.ListenAndServe())
}
// myHandler obtains the connection serving r through GetConn.
// The rest of the body is elided.
func myHandler(w http.ResponseWriter, r *http.Request) {
conn := GetConn(r)
...
}
Until then ... For a server listening on a TCP port, net.Conn.RemoteAddr().String() is unique for each connection and is available to the http.Handler as r.RemoteAddr, so it can be used as a key to a global map of Conns:
package main
import (
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
)
var (
	// connsMu guards conns: the server invokes ConnStateEvent and the
	// handlers from several goroutines, and unsynchronized map access from
	// more than one goroutine is a data race.
	connsMu sync.Mutex
	// conns maps a connection's remote address string to the connection.
	conns = make(map[string]net.Conn)
)

// ConnStateEvent is meant to be installed as http.Server.ConnState. It
// registers a connection when it becomes active and forgets it once it is
// hijacked or closed; other states are ignored.
func ConnStateEvent(conn net.Conn, event http.ConnState) {
	switch event {
	case http.StateActive:
		key := conn.RemoteAddr().String()
		connsMu.Lock()
		conns[key] = conn
		connsMu.Unlock()
	case http.StateHijacked, http.StateClosed:
		key := conn.RemoteAddr().String()
		connsMu.Lock()
		delete(conns, key)
		connsMu.Unlock()
	}
}

// GetConn returns the connection registered for r.RemoteAddr, or nil when
// none is registered.
func GetConn(r *http.Request) net.Conn {
	connsMu.Lock()
	defer connsMu.Unlock()
	return conns[r.RemoteAddr]
}
// main serves myHandler on :8080 with ConnStateEvent installed as the
// server's connection-state hook.
func main() {
	http.HandleFunc("/", myHandler)
	server := &http.Server{
		Addr:      ":8080",
		ConnState: ConnStateEvent,
	}
	// ListenAndServe only returns on failure; report it instead of exiting
	// silently.
	log.Fatal(server.ListenAndServe())
}
// myHandler looks up the connection serving r through GetConn.
// The rest of the body is elided.
func myHandler(w http.ResponseWriter, r *http.Request) {
conn := GetConn(r)
...
}
For a server listening on a UNIX socket, net.Conn.RemoteAddr().String() is always "#", so the above doesn't work. To make this work, we can override net.Listener.Accept(), and use that to override net.Conn.RemoteAddr().String() so that it returns a unique string for each connection:
package main
import (
"net/http"
"net"
"os"
"golang.org/x/sys/unix"
"fmt"
"log"
)
// main serves myHandler on the UNIX socket /var/run/go_server.sock through a
// listener wrapped by NewConnSaveListener, and removes the socket file when
// the server stops.
func main() {
	http.HandleFunc("/", myHandler)
	listenPath := "/var/run/go_server.sock"
	l, err := NewUnixListener(listenPath)
	if err != nil {
		log.Fatal(err)
	}
	defer os.Remove(listenPath)
	server := http.Server{
		ConnState: ConnStateEvent,
	}
	// Report why the server stopped. log.Print rather than log.Fatal, so the
	// deferred removal of the socket file still runs.
	if err := server.Serve(NewConnSaveListener(l)); err != nil {
		log.Print(err)
	}
}
// myHandler logs the UID of the peer process when the request arrived over a
// UNIX socket connection; for any other connection it does nothing.
func myHandler(w http.ResponseWriter, r *http.Request) {
	// The two-value assertion is also false when GetConn returns nil.
	unixConn, isUnix := GetConn(r).(*net.UnixConn)
	if !isUnix {
		return
	}

	f, err := unixConn.File()
	if err != nil {
		log.Printf("getting socket file: %v", err)
		return
	}
	// Close the *os.File obtained above on every path.
	defer f.Close()

	pcred, err := unix.GetsockoptUcred(int(f.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
	if err != nil {
		// Without this check a failed lookup would dereference a nil pcred.
		log.Printf("reading peer credentials: %v", err)
		return
	}
	log.Printf("Remote UID: %d", pcred.Uid)
}
// conns maps the key assigned to an accepted connection to that connection.
var conns = map[string]net.Conn{}

// connSaveListener decorates a net.Listener; everything except Accept is
// inherited from the embedded listener.
type connSaveListener struct {
	net.Listener
}

// NewConnSaveListener wraps the given listener in a connSaveListener.
func NewConnSaveListener(wrap net.Listener) net.Listener {
	return connSaveListener{Listener: wrap}
}
// Accept accepts the next connection from the wrapped listener, records it in
// conns under a generated key, and returns it wrapped so that RemoteAddr()
// reports that key. A failed accept is returned as-is and records nothing.
func (l connSaveListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		// Do not register a nil connection: the server never reports a state
		// change for it, so its entry would stay in conns forever.
		return nil, err
	}
	// The key is the address of the local conn variable printed as a number.
	// NOTE(review): that address is only guaranteed distinct while the
	// variable is reachable — confirm keys cannot repeat for live connections.
	ptrStr := fmt.Sprintf("%d", &conn)
	conns[ptrStr] = conn
	return remoteAddrPtrConn{conn, ptrStr}, nil
}
// GetConn returns the connection stored in conns under r.RemoteAddr, or nil
// when there is no such entry.
func GetConn(r *http.Request) net.Conn {
	found := conns[r.RemoteAddr]
	return found
}

// ConnStateEvent removes a connection from conns once it is hijacked or
// closed; every other state is ignored.
func ConnStateEvent(conn net.Conn, event http.ConnState) {
	switch event {
	case http.StateHijacked, http.StateClosed:
		delete(conns, conn.RemoteAddr().String())
	}
}
// remoteAddrPtrConn is a net.Conn whose RemoteAddr reports ptrStr instead of
// the embedded connection's own remote address.
type remoteAddrPtrConn struct {
	net.Conn
	ptrStr string
}

// RemoteAddr returns an address whose String() is the stored ptrStr.
func (c remoteAddrPtrConn) RemoteAddr() net.Addr {
	addr := remoteAddrPtr{ptrStr: c.ptrStr}
	return addr
}

// remoteAddrPtr is a net.Addr that carries only a string.
type remoteAddrPtr struct {
	ptrStr string
}

// Network always reports the empty string.
func (remoteAddrPtr) Network() string { return "" }

// String returns the stored string.
func (a remoteAddrPtr) String() string { return a.ptrStr }
// NewUnixListener removes whatever exists at path (a missing file is fine),
// listens on a UNIX socket there and sets the socket file's mode to 0660.
// While the socket is created the process umask is 0777; the previous umask
// is restored on return.
func NewUnixListener(path string) (net.Listener, error) {
	unlinkErr := unix.Unlink(path)
	if unlinkErr != nil && !os.IsNotExist(unlinkErr) {
		return nil, unlinkErr
	}

	previous := unix.Umask(0777)
	defer unix.Umask(previous)

	ln, listenErr := net.Listen("unix", path)
	if listenErr != nil {
		return nil, listenErr
	}

	chmodErr := os.Chmod(path, 0660)
	if chmodErr == nil {
		return ln, nil
	}
	// give the listener back before reporting the chmod failure
	ln.Close()
	return nil, chmodErr
}
Note that although in current implementation http.ResponseWriter is a *http.response (note the lowercase!) which holds the connection, the field is unexported and you can't access it.
Instead take a look at the Server.ConnState hook: you can "register" a function which will be called when the connection state changes, see http.ConnState for details. For example you will get the net.Conn even before the request enters the handler (http.StateNew and http.StateActive states).
You can install a connection state listener by creating a custom Server like this:
// main serves myHandler on :8081 from a custom server that has
// ConnStateListener installed as its connection-state hook, and panics with
// the error the server returns.
func main() {
	http.HandleFunc("/", myHandler)

	const timeout = 10 * time.Second
	srv := &http.Server{
		Addr:           ":8081",
		ReadTimeout:    timeout,
		WriteTimeout:   timeout,
		MaxHeaderBytes: 1 << 20,
		ConnState:      ConnStateListener,
	}

	serveErr := srv.ListenAndServe()
	panic(serveErr)
}
// ConnStateListener prints each connection state change together with the
// connection it applies to.
func ConnStateListener(c net.Conn, cs http.ConnState) {
fmt.Printf("CONN STATE: %v, %v\n", cs, c)
}
This way you will have exactly the desired net.Conn even before (and also during and after) invoking the handler. The downside is that it is not "paired" with the ResponseWriter, you have to do that manually if you need that.
You can use an http.Hijacker to take over the TCP connection from the ResponseWriter. Once you've done that you're free to use the socket to do whatever you want.
See http://golang.org/pkg/net/http/#Hijacker, which also contains a good example.
This can be done with reflection. it's a bit "dirty" but it works:
package main
import "net/http"
import "fmt"
import "runtime"
import "reflect"
// myHandler walks from w down to the socket file descriptor using reflection
// on unexported fields, prints what it finds, and replies "Hello World".
// NOTE(review): every FieldByName lookup below relies on the private field
// names of net/http and net as they were when this was written — confirm
// against the Go version in use before relying on it.
func myHandler(w http.ResponseWriter, r *http.Request) {
ptrVal := reflect.ValueOf(w)
val := reflect.Indirect(ptrVal)
// w is a "http.response" struct from which we get the 'conn' field
valconn := val.FieldByName("conn")
val1 := reflect.Indirect(valconn)
// which is a http.conn from which we get the 'rwc' field
ptrRwc := val1.FieldByName("rwc").Elem()
rwc := reflect.Indirect(ptrRwc)
// which is net.TCPConn from which we get the embedded conn
val1conn := rwc.FieldByName("conn")
val2 := reflect.Indirect(val1conn)
// which is a net.conn from which we get the 'fd' field
fdmember := val2.FieldByName("fd")
val3 := reflect.Indirect(fdmember)
// which is a netFD from which we get the 'sysfd' field
netFdPtr := val3.FieldByName("sysfd")
fmt.Printf("netFDPtr= %v\n", netFdPtr)
// which is the system socket (type is platform specific - Int for linux)
if runtime.GOOS == "linux" {
fd := int(netFdPtr.Int())
fmt.Printf("fd = %v\n", fd)
// fd is the socket - we can call unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd),....) on it for instance
}
fmt.Fprintf(w, "Hello World")
}
// main serves myHandler on :8081 and prints the error the server stops with.
func main() {
	http.HandleFunc("/", myHandler)
	serveErr := http.ListenAndServe(":8081", nil)
	message := serveErr.Error()
	fmt.Println(message)
}
Ideally the library should be augmented with a method to get the underlying net.Conn
Expanding on KGJV's answer, a working solution using reflection to maintain a map of connections indexed by net.Conn instance memory addresses.
Instances of net.Conn can be looked up by pointer, and the pointers are derived using reflection against the http.ResponseWriter.
It's a bit nasty, but given you can't access unpublished fields with reflection it's the only way I could see of doing it.
// Connection map indexed by connection address
var conns = make(map[uintptr]net.Conn)
// connMutex is locked around every access to conns in this snippet.
var connMutex = sync.Mutex{}
// writerToConnPtr converts an http.ResponseWriter to a pointer for indexing
// NOTE(review): the FieldByName lookups depend on the private field names of
// net/http and net — confirm against the Go version in use.
func writerToConnPtr(w http.ResponseWriter) uintptr {
ptrVal := reflect.ValueOf(w)
val := reflect.Indirect(ptrVal)
// http.conn
valconn := val.FieldByName("conn")
val1 := reflect.Indirect(valconn)
// net.TCPConn
ptrRwc := val1.FieldByName("rwc").Elem()
rwc := reflect.Indirect(ptrRwc)
// net.Conn
val1conn := rwc.FieldByName("conn")
val2 := reflect.Indirect(val1conn)
// the address of that value is the map key
return val2.Addr().Pointer()
}
// connToPtr returns the pointer value held by c as a uintptr, for use as a
// map key.
func connToPtr(c net.Conn) uintptr {
	return reflect.ValueOf(c).Pointer()
}
// ConnStateListener is bound to the server's ConnState hook. It adds a
// connection to conns when it is new and removes it when it is closed, keyed
// by the connection's pointer value; other states change nothing.
func ConnStateListener(c net.Conn, cs http.ConnState) {
	key := connToPtr(c)

	connMutex.Lock()
	defer connMutex.Unlock()

	if cs == http.StateNew {
		log.Printf("CONN Opened: 0x%x\n", key)
		conns[key] = c
		return
	}
	if cs == http.StateClosed {
		log.Printf("CONN Closed: 0x%x\n", key)
		delete(conns, key)
	}
}
// HandleRequest looks up the connection that serves this request, using the
// pointer derived from w as the key into conns. When no connection is
// registered under that key it logs an error and returns.
func HandleRequest(w http.ResponseWriter, r *http.Request) {
	connPtr := writerToConnPtr(w)

	// Hold the lock only for the map lookup. Keeping it for the rest of the
	// handler would make every request, and every connection state change,
	// wait for this one to finish.
	connMutex.Lock()
	conn, ok := conns[connPtr]
	connMutex.Unlock()

	if !ok {
		log.Printf("error: no matching connection found")
		return
	}
	// Do something with connection here...
	_ = conn // placeholder so the snippet compiles until conn is actually used
}
// Bind with http.Server.ConnState = ConnStateListener
It looks like you cannot "pair" a socket (or net.Conn) to either http.Request or http.ResponseWriter.
But you can implement your own Listener:
package main
import (
"fmt"
"net"
"net/http"
"time"
"log"
)
func main() {
// init http server
m := &MyHandler{}
s := &http.Server{
Handler: m,
}
// create custom listener
nl, err := net.Listen("tcp", ":8080")
if err != nil {
log.Fatal(err)
}
l := &MyListener{nl}
// serve through custom listener
err = s.Serve(l)
if err != nil {
log.Fatal(err)
}
}
// MyConn implements net.Conn by forwarding every call to the wrapped
// connection nc.
type MyConn struct {
	nc net.Conn
}

// Read forwards to the wrapped connection.
func (mc MyConn) Read(p []byte) (int, error) { return mc.nc.Read(p) }

// Write forwards to the wrapped connection.
func (mc MyConn) Write(p []byte) (int, error) { return mc.nc.Write(p) }

// Close forwards to the wrapped connection.
func (mc MyConn) Close() error { return mc.nc.Close() }

// LocalAddr forwards to the wrapped connection.
func (mc MyConn) LocalAddr() net.Addr { return mc.nc.LocalAddr() }

// RemoteAddr forwards to the wrapped connection.
func (mc MyConn) RemoteAddr() net.Addr { return mc.nc.RemoteAddr() }

// SetDeadline forwards to the wrapped connection.
func (mc MyConn) SetDeadline(deadline time.Time) error { return mc.nc.SetDeadline(deadline) }

// SetReadDeadline forwards to the wrapped connection.
func (mc MyConn) SetReadDeadline(deadline time.Time) error { return mc.nc.SetReadDeadline(deadline) }

// SetWriteDeadline forwards to the wrapped connection.
func (mc MyConn) SetWriteDeadline(deadline time.Time) error {
	return mc.nc.SetWriteDeadline(deadline)
}
// MyListener implements net.Listener on top of the wrapped listener nl and
// hands out every accepted connection wrapped in a MyConn.
type MyListener struct {
	nl net.Listener
}

// Accept waits for the next connection of the wrapped listener and returns
// it as a MyConn; an accept error is returned unchanged with a nil conn.
func (ml MyListener) Accept() (net.Conn, error) {
	accepted, acceptErr := ml.nl.Accept()
	if acceptErr != nil {
		return nil, acceptErr
	}
	return MyConn{accepted}, nil
}

// Close forwards to the wrapped listener.
func (ml MyListener) Close() error { return ml.nl.Close() }

// Addr forwards to the wrapped listener.
func (ml MyListener) Addr() net.Addr { return ml.nl.Addr() }
// MyHandler is an http.Handler that replies "Hello World" to every request.
type MyHandler struct {
	// ...
}

// ServeHTTP writes the fixed greeting to w; r is not inspected.
func (h *MyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const greeting = "Hello World"
	fmt.Fprint(w, greeting)
}