Go Router

Go router

源码分析

先看一下最简单的 web server 服务吧:

// Register a handler for "/" that answers every request with a fixed greeting.
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
    // The ResponseWriter method is the exported Write; a lowercase
    // w.write does not exist and would not compile.
    w.Write([]byte("hello world!"))
})
// Passing nil as the handler makes the server fall back to DefaultServeMux,
// which is where http.HandleFunc registered the route above.
http.ListenAndServe(":8080", nil)

先观察一下这段代码:它其实就是一个监听本地 8080 端口的 HTTP 服务,当访问 127.0.0.1:8080 的时候,页面就会输出 hello world!

不难发现,其实只要知道了 http.HandleFunc() 、http.ListenAndServe() 这两个函数干了什么,就知道这个服务是怎么跑起来的了。

那么先查看 HandleFunc() 的源码吧

我们就跳到源码里面的 HandleFunc() 里面

// HandleFunc registers handler for the given pattern by forwarding both
// arguments unchanged to DefaultServeMux.HandleFunc.
func HandleFunc(pattern string, handler func(ResponseWriter, *Request)){
    DefaultServeMux.HandleFunc(pattern, handler)
}

诶,这里好像继续调用 DefaultServeMux 的 HandleFunc 的方法了昂

那这个 DefaultServeMux 是什么呢?

// ServeMux holds the route table that handlers are registered into.
type ServeMux struct{
    // mu is the lock Handle takes while it reads and updates the fields below.
    mu    sync.RWMutex
    // m is the route table: it maps a pattern string to its muxEntry.
    m     map[string]muxEntry
    // hosts is set to true by Handle once a pattern that does not start
    // with '/' has been registered.
    hosts bool
}
// DefaultServeMux is a pointer to the package-level defaultServeMux; it is
// the mux the package-level HandleFunc registers into.
var DefaultServeMux = &defaultServeMux
// defaultServeMux is the zero-value ServeMux backing DefaultServeMux.
var defaultServeMux ServeMux

找到了定义的部分,发现其实本质就是 ServeMux 结构体.

// HandleFunc registers the handler function for the given pattern.
// It panics if handler is nil; otherwise it passes the function through
// HandlerFunc and hands the result to mux.Handle, which performs the
// actual registration.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
	if handler == nil {
		panic("http: nil handler")
	}
	mux.Handle(pattern, HandlerFunc(handler))
}

接着,我们再进到 mux.Handle 里面

// Handle stores handler in the mux's route table under pattern.
// It panics if pattern is empty, if handler is nil, or if pattern has
// already been registered.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
	// Hold the write lock for the whole registration.
	mux.mu.Lock()
	defer mux.mu.Unlock()

	if pattern == "" {
		panic("http: invalid pattern")
	}
	if handler == nil {
		panic("http: nil handler")
	}
	// Reading from a nil map is allowed, so this duplicate check is safe
	// even before mux.m has been created below.
	if _, exist := mux.m[pattern]; exist {
		panic("http: multiple registrations for " + pattern)
	}

	// Create the route table lazily on first registration.
	if mux.m == nil {
		mux.m = make(map[string]muxEntry)
	}
	mux.m[pattern] = muxEntry{h: handler, pattern: pattern}

	// Remember that at least one pattern does not begin with '/'
	// (presumably a host-qualified pattern — the flag's consumer is not shown here).
	if pattern[0] != '/' {
		mux.hosts = true
	}
}

这仔细一看,这就是路由注册的实现啊!好像没有陌生的函数/方法可以跳了吧

所以咱们可以断定,这 http.HandleFunc() 其实就是什么?就是注册路由的部分。

那,处理请求的部分在哪儿啊?诶,别忘了,我们还有 http.ListenAndServe() 没看呢

上源码!

// ListenAndServe builds a Server from addr and handler and returns the
// result of calling that server's ListenAndServe method.
func ListenAndServe(addr string, handler Handler) error {
	server := &Server{Addr: addr, Handler: handler}
	return server.ListenAndServe()
}

首先实例化了一个 Server 结构体吧,这个不难看懂,然后 server.ListenAndServe() 不懂了,跳!

// ListenAndServe opens a TCP listener on srv.Addr and passes it to
// srv.Serve. It returns ErrServerClosed when the server is already
// shutting down, the error from net.Listen if listening fails, and
// otherwise whatever srv.Serve returns.
func (srv *Server) ListenAndServe() error {
	if srv.shuttingDown() {
		return ErrServerClosed
	}
	addr := srv.Addr
	// An empty address falls back to ":http".
	if addr == "" {
		addr = ":http"
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// The listener is asserted to *net.TCPListener and wrapped in
	// tcpKeepAliveListener before being served.
	return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
}

还是不太懂吧,那跳到 srv.Serve 里面看看

// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
//
// Serve returns ErrServerClosed when the listener cannot be tracked or when
// Accept fails after the server's done channel is closed; it returns the
// Accept error itself for any other non-temporary failure.
func (srv *Server) Serve(l net.Listener) error {
	if fn := testHookServerServe; fn != nil {
		fn(srv, l) // call hook with unwrapped listener
	}

	// Wrap the listener in onceCloseListener and close it when Serve returns.
	l = &onceCloseListener{Listener: l}
	defer l.Close()

	if err := srv.setupHTTP2_Serve(); err != nil {
		return err
	}

	// Register the listener with the server for the lifetime of this call.
	if !srv.trackListener(&l, true) {
		return ErrServerClosed
	}
	defer srv.trackListener(&l, false)

	var tempDelay time.Duration     // how long to sleep on accept failure
	baseCtx := context.Background() // base is always background, per Issue 16220
	ctx := context.WithValue(baseCtx, ServerContextKey, srv)
	for {
		rw, e := l.Accept()
		if e != nil {
			// Non-blocking check: if the done channel is ready, report
			// ErrServerClosed instead of the raw Accept error.
			select {
			case <-srv.getDoneChan():
				return ErrServerClosed
			default:
			}
			// Temporary errors are retried with a backoff that starts at
			// 5ms, doubles on each consecutive failure and is capped at 1s.
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				srv.logf("http: Accept error: %v; retrying in %v", e, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return e
		}
		// A successful Accept resets the backoff.
		tempDelay = 0
		c := srv.newConn(rw)
		c.setState(c.rwc, StateNew) // before Serve can return
		// Each accepted connection is served on its own goroutine.
		go c.serve(ctx)
	}
}

你看注释,说 Serve 为每个连接(注意是连接,不是每个请求)开启了一个 goroutine,就是 go c.serve(ctx) 这个部分,那咱们就看看每个 goroutine 都干了啥呗,跳!

// serve handles a single accepted connection. For a TLS connection it first
// performs the handshake and, when a protocol handler is registered in
// TLSNextProto for the negotiated protocol, hands the connection to it.
// Otherwise it loops: read one request, dispatch it through
// serverHandler{c.server}.ServeHTTP, finish the response, and repeat until
// the connection is hijacked, fails, or should not be reused.
func (c *conn) serve(ctx context.Context) {
	c.remoteAddr = c.rwc.RemoteAddr().String()
	ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
	defer func() {
		// Recover from a handler panic (except ErrAbortHandler) and log it
		// with up to 64 KiB of this goroutine's stack trace.
		if err := recover(); err != nil && err != ErrAbortHandler {
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
		}
		// A hijacked connection is left open; otherwise close it and
		// record the closed state.
		if !c.hijacked() {
			c.close()
			c.setState(c.rwc, StateClosed)
		}
	}()

	if tlsConn, ok := c.rwc.(*tls.Conn); ok {
		// Apply the server's read/write timeouts before the handshake.
		if d := c.server.ReadTimeout; d != 0 {
			c.rwc.SetReadDeadline(time.Now().Add(d))
		}
		if d := c.server.WriteTimeout; d != 0 {
			c.rwc.SetWriteDeadline(time.Now().Add(d))
		}
		if err := tlsConn.Handshake(); err != nil {
			c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
			return
		}
		c.tlsState = new(tls.ConnectionState)
		*c.tlsState = tlsConn.ConnectionState()
		// If the negotiated protocol passes validNPN, run the handler
		// registered for it in TLSNextProto (if any) and stop serving
		// this connection here either way.
		if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
			if fn := c.server.TLSNextProto[proto]; fn != nil {
				h := initNPNRequest{tlsConn, serverHandler{c.server}}
				fn(c.server, tlsConn, h)
			}
			return
		}
	}

	// HTTP/1.x from here on.

	ctx, cancelCtx := context.WithCancel(ctx)
	c.cancelCtx = cancelCtx
	defer cancelCtx()

	// Set up the connection's reader and its buffered reader/writer
	// (the writer buffer is 4 KiB).
	c.r = &connReader{conn: c}
	c.bufr = newBufioReader(c.r)
	c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)

	// One iteration per request on this connection.
	for {
		w, err := c.readRequest(ctx)
		if c.r.remain != c.server.initialReadLimitSize() {
			// If we read any bytes off the wire, we're active.
			c.setState(c.rwc, StateActive)
		}
		if err != nil {
			const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"

			if err == errTooLarge {
				// Their HTTP client may or may not be
				// able to read this if we're
				// responding to them and hanging up
				// while they're still writing their
				// request. Undefined behavior.
				const publicErr = "431 Request Header Fields Too Large"
				fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
				c.closeWriteAndWait()
				return
			}
			if isCommonNetReadError(err) {
				return // don't reply
			}

			// Any other read error is answered with 400 Bad Request; a
			// badRequestError contributes its text to the status line and body.
			publicErr := "400 Bad Request"
			if v, ok := err.(badRequestError); ok {
				publicErr = publicErr + ": " + string(v)
			}

			fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
			return
		}

		// Expect 100 Continue support
		req := w.req
		if req.expectsContinue() {
			if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
				// Wrap the Body reader with one that replies on the connection
				req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
			}
		} else if req.Header.get("Expect") != "" {
			// A non-empty Expect header that expectsContinue rejects ends
			// the connection with an expectation-failed response.
			w.sendExpectationFailed()
			return
		}

		// Publish the in-flight response as the connection's current request.
		c.curReq.Store(w)

		// Start the background read now, or register it to start once the
		// request body hits EOF if some of the body remains.
		if requestBodyRemains(req.Body) {
			registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
		} else {
			w.conn.r.startBackgroundRead()
		}

		// HTTP cannot have multiple simultaneous active requests.[*]
		// Until the server replies to this request, it can't read another,
		// so we might as well run the handler in this goroutine.
		// [*] Not strictly true: HTTP pipelining. We could let them all process
		// in parallel even if their responses need to be serialized.
		// But we're not going to implement HTTP pipelining because it
		// was never deployed in the wild and the answer is HTTP/2.
		serverHandler{c.server}.ServeHTTP(w, w.req)
		w.cancelCtx()
		// After a hijack the handler owns the connection; stop serving it.
		if c.hijacked() {
			return
		}
		w.finishRequest()
		if !w.shouldReuseConnection() {
			if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
				c.closeWriteAndWait()
			}
			return
		}
		// The connection will be reused: mark it idle and clear the
		// current request.
		c.setState(c.rwc, StateIdle)
		c.curReq.Store((*response)(nil))

		if !w.conn.server.doKeepAlives() {
			// We're in shutdown mode. We might've replied
			// to the user without "Connection: close" and
			// they might think they can send another
			// request, but such is life with HTTP/1.1.
			return
		}

		// With an idle timeout configured, wait (bounded by that timeout)
		// for the first bytes of the next request; give up on any error.
		if d := c.server.idleTimeout(); d != 0 {
			c.rwc.SetReadDeadline(time.Now().Add(d))
			if _, err := c.bufr.Peek(4); err != nil {
				return
			}
		}
		// Clear the read deadline before reading the next request.
		c.rwc.SetReadDeadline(time.Time{})
	}
}

这真是太长了hh, 但是咱们只关注最重要的部分就可以了(在里面找ServeHTTP),比如,这个其实就是调用了什么?

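就是在每个连接(conn)的 serve 方法里调用了 serverHandler{c.server}.ServeHTTP(w, w.req)!当 Server 的 Handler 为 nil 时(比如开头示例里给 ListenAndServe 传的 nil),最终会使用 DefaultServeMux 来处理请求,也就是 http.HandleFunc() 注册路由的那个 ServeMux。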

总结一下

其实 HandleFunc() 就是什么呢,其实它就是注册路由的部分,不是什么处理请求的部分,处理请求的部分是什么?是 ServeHTTP() 啊,那是哪个调用到了 ServeHTTP() 呢?是 ListenAndServe() 啊!那,所以要先注册路由,因为在执行 ServeHTTP() 的时候呢,我得看路由表提供服务啊!!

那啥叫看路由表提供服务啊?那咱们先说说啥是路由表吧。路由表是 ServeMux 结构体里的一个字段,啥字段?就是 m 这个 map[string]muxEntry 。它存储的是什么呢?key 是 pattern,value 是一个 muxEntry,里面保存了该 pattern 以及它对应的处理器(handler)。

动手撸一个自己的路由

package route

import (
    "net/http"
    "strings"
)

// NewRouter returns an empty Router that is ready to have routes registered.
func NewRouter() *Router {
    return &Router{Route: make(map[string]map[string]http.HandlerFunc)}
}

// Router is a minimal HTTP router that dispatches on the request method and
// the exact URL path. The zero value is usable: HandleFunc creates the
// route table on first use.
//
// The route table is a plain map without locking, so register all routes
// before handing the Router to an HTTP server.
type Router struct {
    // Route maps an upper-case HTTP method to a table of exact URL paths
    // and the handler registered for each path.
    Route map[string]map[string]http.HandlerFunc
}

// Compile-time check that *Router implements http.Handler.
var _ http.Handler = (*Router)(nil)

// ServeHTTP implements http.Handler. It looks up the handler registered for
// the request's method and URL path and invokes it; when no route matches it
// replies with 404 Not Found.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    // Match on URL.Path, not URL.String(): the latter also carries the query
    // string, so "/?a=1" would never match a route registered as "/".
    // Indexing a nil map is safe, so an empty Router simply falls through.
    if h, ok := r.Route[req.Method][req.URL.Path]; ok {
        h(w, req)
        return
    }
    // Without an explicit reply the client would receive an empty 200.
    http.NotFound(w, req)
}

// HandleFunc registers f as the handler for requests whose method equals
// method (compared case-insensitively at registration time) and whose URL
// path equals path. Registering the same method and path again replaces the
// previous handler.
func (r *Router) HandleFunc(method, path string, f http.HandlerFunc) {
    // Request methods arrive upper-case, so normalize the registration key.
    method = strings.ToUpper(method)
    if r.Route == nil {
        r.Route = make(map[string]map[string]http.HandlerFunc)
    }
    if r.Route[method] == nil {
        r.Route[method] = make(map[string]http.HandlerFunc)
    }
    r.Route[method][path] = f
}

使用:

r := route.NewRouter()
// The request parameter is named req so that it does not shadow the router r.
r.HandleFunc("GET", "/", func(w http.ResponseWriter, req *http.Request) {
    fmt.Fprint(w, "Hello Get!")
})
r.HandleFunc("POST", "/", func(w http.ResponseWriter, req *http.Request) {
    fmt.Fprint(w, "hello POST!")
})
// ListenAndServe only returns when the server fails or stops, always with a
// non-nil error; report it instead of silently dropping it.
if err := http.ListenAndServe(":8080", r); err != nil {
    fmt.Println("http server stopped:", err)
}