[Pkg-privacy-commits] [obfs4proxy] 24/151: Kill Obfs4Conn.isOk with fire, and replace it with a state var.

Ximin Luo infinity0 at moszumanska.debian.org
Sat Aug 22 12:59:35 UTC 2015


This is an automated email from the git hooks/post-receive script.

infinity0 pushed a commit to branch master
in repository obfs4proxy.

commit ded3f6948cc572cd57035a093bc884b56d939463
Author: Yawning Angel <yawning at schwanenlied.me>
Date:   Wed May 14 08:06:27 2014 +0000

    Kill Obfs4Conn.isOk with fire, and replace it with a state var.
---
 obfs4.go  | 145 ++++++++++++++++++++++++++++++++++++++++----------------------
 packet.go |  40 +++++++++++------
 2 files changed, 119 insertions(+), 66 deletions(-)

diff --git a/obfs4.go b/obfs4.go
index afe8967..eadcbef 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -51,6 +51,15 @@ const (
 	maxCloseInterval  = 60
 )
 
+type connState int
+
+const (
+	stateInit connState = iota
+	stateEstablished
+	stateBroken
+	stateClosed
+)
+
 // Obfs4Conn is the implementation of the net.Conn interface for obfs4
 // connections.
 type Obfs4Conn struct {
@@ -64,7 +73,7 @@ type Obfs4Conn struct {
 	receiveBuffer        bytes.Buffer
 	receiveDecodedBuffer bytes.Buffer
 
-	isOk     bool
+	state    connState
 	isServer bool
 
 	// Server side state.
@@ -111,51 +120,65 @@ func (c *Obfs4Conn) closeAfterDelay() {
 	}
 }
 
-func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) error {
+func (c *Obfs4Conn) setBroken() {
+	c.state = stateBroken
+}
+
+func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) (err error) {
 	if c.isServer {
 		panic(fmt.Sprintf("BUG: clientHandshake() called for server connection"))
 	}
 
+	defer func() {
+		if err != nil {
+			c.setBroken()
+		}
+	}()
+
 	// Generate/send the client handshake.
-	hs, err := newClientHandshake(nodeID, publicKey)
+	var hs *clientHandshake
+	var blob []byte
+	hs, err = newClientHandshake(nodeID, publicKey)
 	if err != nil {
-		return err
+		return
 	}
-	blob, err := hs.generateHandshake()
+	blob, err = hs.generateHandshake()
 	if err != nil {
-		return err
+		return
 	}
 
 	err = c.conn.SetDeadline(time.Now().Add(connectionTimeout * 2))
 	if err != nil {
-		return err
+		return
 	}
 
 	_, err = c.conn.Write(blob)
 	if err != nil {
-		return err
+		return
 	}
 
 	// Consume the server handshake.
-	hsBuf := make([]byte, serverMaxHandshakeLength)
+	var hsBuf [serverMaxHandshakeLength]byte
 	for {
-		n, err := c.conn.Read(hsBuf)
+		var n int
+		n, err = c.conn.Read(hsBuf[:])
 		if err != nil {
-			return err
+			return
 		}
 		c.receiveBuffer.Write(hsBuf[:n])
 
-		n, seed, err := hs.parseServerHandshake(c.receiveBuffer.Bytes())
+		var seed []byte
+		n, seed, err = hs.parseServerHandshake(c.receiveBuffer.Bytes())
 		if err == ErrMarkNotFoundYet {
 			continue
 		} else if err != nil {
-			return err
+			return
 		}
 		_ = c.receiveBuffer.Next(n)
 
 		err = c.conn.SetDeadline(time.Time{})
 		if err != nil {
-			return err
+			return
 		}
 
 		// Use the derived key material to intialize the link crypto.
@@ -163,37 +186,45 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
 		c.encoder = framing.NewEncoder(okm[:framing.KeyLength])
 		c.decoder = framing.NewDecoder(okm[framing.KeyLength:])
 
-		c.isOk = true
+		c.state = stateEstablished
 
 		return nil
 	}
 }
 
-func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) error {
+func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) (err error) {
 	if !c.isServer {
 		panic(fmt.Sprintf("BUG: serverHandshake() called for client connection"))
 	}
 
+	defer func() {
+		if err != nil {
+			c.setBroken()
+		}
+	}()
+
 	hs := newServerHandshake(nodeID, keypair)
-	err := c.conn.SetDeadline(time.Now().Add(connectionTimeout))
+	err = c.conn.SetDeadline(time.Now().Add(connectionTimeout))
 	if err != nil {
-		return err
+		return
 	}
 
 	// Consume the client handshake.
-	hsBuf := make([]byte, clientMaxHandshakeLength)
+	var hsBuf [clientMaxHandshakeLength]byte
 	for {
-		n, err := c.conn.Read(hsBuf)
+		var n int
+		n, err = c.conn.Read(hsBuf[:])
 		if err != nil {
-			return err
+			return
 		}
 		c.receiveBuffer.Write(hsBuf[:n])
 
-		seed, err := hs.parseClientHandshake(c.receiveBuffer.Bytes())
+		var seed []byte
+		seed, err = hs.parseClientHandshake(c.receiveBuffer.Bytes())
 		if err == ErrMarkNotFoundYet {
 			continue
 		} else if err != nil {
-			return err
+			return
 		}
 		c.receiveBuffer.Reset()
 
@@ -206,46 +237,51 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair)
 	}
 
 	// Generate/send the response.
-	blob, err := hs.generateHandshake()
+	var blob []byte
+	blob, err = hs.generateHandshake()
 	if err != nil {
-		return err
+		return
 	}
 	_, err = c.conn.Write(blob)
 	if err != nil {
-		return err
+		return
 	}
 
 	err = c.conn.SetDeadline(time.Time{})
 	if err != nil {
-		return err
+		return
 	}
 
-	c.isOk = true
+	c.state = stateEstablished
 
 	// TODO: Generate/send the PRNG seed.
 
 	return nil
 }
 
+func (c *Obfs4Conn) CanHandshake() bool {
+	return c.state == stateInit
+}
+
+func (c *Obfs4Conn) CanReadWrite() bool {
+	return c.state == stateEstablished
+}
+
 func (c *Obfs4Conn) ServerHandshake() error {
 	// Handshakes when already established are a no-op.
-	if c.isOk {
+	if c.CanReadWrite() {
 		return nil
+	} else if !c.CanHandshake() {
+		return syscall.EINVAL
 	}
 
-	// Clients handshake as part of Dial.
 	if !c.isServer {
 		panic(fmt.Sprintf("BUG: ServerHandshake() called for client connection"))
 	}
 
-	// Regardless of what happens, don't need the listener past returning from
-	// this routine.
-	defer func() {
-		c.listener = nil
-	}()
-
 	// Complete the handshake.
 	err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair)
+	c.listener = nil
 	if err != nil {
 		c.closeAfterDelay()
 	}
@@ -254,7 +290,7 @@ func (c *Obfs4Conn) ServerHandshake() error {
 }
 
 func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
-	if !c.isOk {
+	if !c.CanReadWrite() {
 		return 0, syscall.EINVAL
 	}
 
@@ -272,20 +308,19 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
 }
 
 func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
-	if !c.isOk {
+	if !c.CanReadWrite() {
 		return 0, syscall.EINVAL
 	}
 
-	wrLen := 0
-
 	// If there is buffered payload from earlier Read() calls, write.
+	wrLen := 0
 	if c.receiveDecodedBuffer.Len() > 0 {
 		wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes())
 		if err != nil {
-			c.isOk = false
+			c.setBroken()
 			return int64(wrLen), err
 		} else if wrLen < int(c.receiveDecodedBuffer.Len()) {
-			c.isOk = false
+			c.setBroken()
 			return int64(wrLen), io.ErrShortWrite
 		}
 		c.receiveDecodedBuffer.Reset()
@@ -309,6 +344,17 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
 }
 
 func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
+	if !c.CanReadWrite() {
+		return 0, syscall.EINVAL
+	}
+
+	defer func() {
+		if err != nil {
+			c.setBroken()
+		}
+	}()
+
+	// XXX: Change this to write directly to c.conn skipping frameBuf.
 	chopBuf := bytes.NewBuffer(b)
 	var payload [maxPacketPayloadLength]byte
 	var frameBuf bytes.Buffer
@@ -318,7 +364,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
 		rdLen := 0
 		rdLen, err = chopBuf.Read(payload[:])
 		if err != nil {
-			c.isOk = false
 			return 0, err
 		} else if rdLen == 0 {
 			panic(fmt.Sprintf("BUG: Write(), chopping length was 0"))
@@ -327,7 +372,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
 
 		err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0)
 		if err != nil {
-			c.isOk = false
 			return 0, err
 		}
 	}
@@ -340,20 +384,17 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
 		err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
 			uint16(padLen-headerLength))
 		if err != nil {
-			c.isOk = false
 			return 0, err
 		}
 	} else if padLen > 0 {
 		err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
 			maxPacketPayloadLength)
 		if err != nil {
-			c.isOk = false
 			return 0, err
 		}
 		err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
 			uint16(padLen))
 		if err != nil {
-			c.isOk = false
 			return 0, err
 		}
 	}
@@ -364,7 +405,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
 		// Partial writes are fatal because the frame encoder state is advanced
 		// at this point.  It's possible to keep frameBuf around, but fuck it.
 		// Someone that wants write timeouts can change this.
-		c.isOk = false
 		return 0, err
 	}
 
@@ -376,13 +416,13 @@ func (c *Obfs4Conn) Close() error {
 		return syscall.EINVAL
 	}
 
-	c.isOk = false
+	c.state = stateClosed
 
 	return c.conn.Close()
 }
 
 func (c *Obfs4Conn) LocalAddr() net.Addr {
-	if !c.isOk {
+	if c.state == stateClosed {
 		return nil
 	}
 
@@ -390,7 +430,7 @@ func (c *Obfs4Conn) LocalAddr() net.Addr {
 }
 
 func (c *Obfs4Conn) RemoteAddr() net.Addr {
-	if !c.isOk {
+	if c.state == stateClosed {
 		return nil
 	}
 
@@ -402,7 +442,7 @@ func (c *Obfs4Conn) SetDeadline(t time.Time) error {
 }
 
 func (c *Obfs4Conn) SetReadDeadline(t time.Time) error {
-	if !c.isOk {
+	if !c.CanReadWrite() {
 		return syscall.EINVAL
 	}
 
@@ -487,6 +527,7 @@ func (l *Obfs4Listener) PublicKey() string {
 	if l.keyPair == nil {
 		return ""
 	}
+
 	return l.keyPair.Public().Base64()
 }
 
diff --git a/packet.go b/packet.go
index 7b69517..339a86d 100644
--- a/packet.go
+++ b/packet.go
@@ -67,14 +67,24 @@ func (e InvalidPayloadLengthError) Error() string {
 
 var zeroPadBytes [maxPacketPaddingLength]byte
 
-func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error {
+func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) {
 	var pkt [framing.MaximumFramePayloadLength]byte
 
+	if !c.CanReadWrite() {
+		return syscall.EINVAL
+	}
+
 	if len(data)+int(padLen) > maxPacketPayloadLength {
 		panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d",
 			len(data), padLen, maxPacketPayloadLength))
 	}
 
+	defer func() {
+		if err != nil {
+			c.setBroken()
+		}
+	}()
+
 	// Packets are:
 	//   uint8_t type      packetTypePayload (0x00)
 	//   uint16_t length   Length of the payload (Big Endian).
@@ -91,31 +101,32 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
 
 	// Encode the packet in an AEAD frame.
 	// TODO: Change Encode to write into frame directly
-	_, frame, err := c.encoder.Encode(pkt[:pktLen])
+	var frame []byte
+	_, frame, err = c.encoder.Encode(pkt[:pktLen])
 	if err != nil {
 		// All encoder errors are fatal.
-		c.isOk = false
-		return err
+		return
 	}
-	wrLen, err := w.Write(frame)
+	var wrLen int
+	wrLen, err = w.Write(frame)
 	if err != nil {
-		c.isOk = false
-		return err
+		return
 	} else if wrLen < len(frame) {
-		c.isOk = false
-		return io.ErrShortWrite
+		err = io.ErrShortWrite
+		return
 	}
 
-	return nil
+	return
 }
 
 func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
-	if !c.isOk {
+	if !c.CanReadWrite() {
 		return n, syscall.EINVAL
 	}
 
 	var buf [consumeReadSize]byte
-	rdLen, err := c.conn.Read(buf[:])
+	var rdLen int
+	rdLen, err = c.conn.Read(buf[:])
 	if err != nil {
 		return
 	}
@@ -150,7 +161,8 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
 			if payloadLen > 0 {
 				if w != nil {
 					// c.WriteTo() skips buffering in c.receiveDecodedBuffer
-					wrLen, err := w.Write(payload)
+					var wrLen int
+					wrLen, err = w.Write(payload)
 					n += wrLen
 					if err != nil {
 						break
@@ -176,7 +188,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
 
 	// All errors that reach this point are fatal.
 	if err != nil {
-		c.isOk = false
+		c.setBroken()
 	}
 
 	return

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/pkg-privacy/packages/obfs4proxy.git



More information about the Pkg-privacy-commits mailing list