Skip to content
Snippets Groups Projects
Commit 1c696c88 authored by Marcus Weiner's avatar Marcus Weiner
Browse files

Refactor OSPF checksumming

parent 6e88d331
No related branches found
No related tags found
1 merge request!2Packet/ospfv3
package fixtures
import (
"net"
"os"
"testing"
"github.com/bio-routing/bio-rd/net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcapgo"
......@@ -35,8 +35,14 @@ func Payload(raw []byte) (pl []byte, src, dst net.IP, err error) {
}
flowSrc, flowDst := packet.NetworkLayer().NetworkFlow().Endpoints()
src = net.IP(flowSrc.Raw())
dst = net.IP(flowDst.Raw())
src, err = net.IPFromBytes(flowSrc.Raw())
if err != nil {
return
}
dst, err = net.IPFromBytes(flowDst.Raw())
if err != nil {
return
}
pl = packet.NetworkLayer().LayerPayload()
return
......
......@@ -4,8 +4,8 @@ import (
"bytes"
"encoding/binary"
"fmt"
gonet "net"
"github.com/bio-routing/bio-rd/net"
"github.com/bio-routing/bio-rd/util/checksum"
"github.com/bio-routing/bio-rd/util/decode"
"github.com/bio-routing/tflow2/convert"
......@@ -42,7 +42,7 @@ const OSPFv3MessageHeaderLength = 16
const OSPFv3MessagePacketLengthAtByte = 2
const OSPFv3MessageChecksumAtByte = 12
func (x *OSPFv3Message) Serialize(out *bytes.Buffer, src, dst gonet.IP) {
func (x *OSPFv3Message) Serialize(out *bytes.Buffer, src, dst net.IP) {
buf := bytes.NewBuffer(nil)
buf.WriteByte(x.Version)
......@@ -60,7 +60,7 @@ func (x *OSPFv3Message) Serialize(out *bytes.Buffer, src, dst gonet.IP) {
length := uint16(len(data))
putUint16(data, OSPFv3MessagePacketLengthAtByte, length)
checksum := checksum.IPv6UpperLayerChecksum(src, dst, OSPFProtocolNumber, data, OSPFv3MessageChecksumAtByte)
checksum := OSPFv3Checksum(data, src, dst)
putUint16(data, OSPFv3MessageChecksumAtByte, checksum)
out.Write(data)
......@@ -70,7 +70,7 @@ func putUint16(b []byte, p int, v uint16) {
binary.BigEndian.PutUint16(b[p:p+2], v)
}
func DeserializeOSPFv3Message(buf *bytes.Buffer, src, dst gonet.IP) (*OSPFv3Message, int, error) {
func DeserializeOSPFv3Message(buf *bytes.Buffer, src, dst net.IP) (*OSPFv3Message, int, error) {
pdu := &OSPFv3Message{}
data := buf.Bytes()
......@@ -99,7 +99,7 @@ func DeserializeOSPFv3Message(buf *bytes.Buffer, src, dst gonet.IP) (*OSPFv3Mess
return nil, readBytes, fmt.Errorf("Invalid OSPF version: %d", pdu.Version)
}
expectedChecksum := checksum.IPv6UpperLayerChecksum(src, dst, OSPFProtocolNumber, data, OSPFv3MessageChecksumAtByte)
expectedChecksum := OSPFv3Checksum(data, src, dst)
if pdu.Checksum != expectedChecksum {
return nil, readBytes, fmt.Errorf("Checksum mismatch. Expected %#04x, got %#04x", expectedChecksum, pdu.Checksum)
}
......@@ -113,6 +113,12 @@ func DeserializeOSPFv3Message(buf *bytes.Buffer, src, dst gonet.IP) (*OSPFv3Mess
return pdu, readBytes, nil
}
func OSPFv3Checksum(data []byte, src, dst net.IP) uint16 {
data[12] = 0
data[13] = 0
return checksum.IPv6UpperLayerChecksum(src, dst, OSPFProtocolNumber, data)
}
func (m *OSPFv3Message) ReadBody(buf *bytes.Buffer) (int, error) {
bodyLength := m.PacketLength - OSPFv3MessageHeaderLength
var body Serializable
......
......@@ -29,14 +29,14 @@ package checksum
import (
"encoding/binary"
"net"
"github.com/bio-routing/bio-rd/net"
"github.com/bio-routing/tflow2/convert"
"golang.org/x/net/icmp"
)
func IPv6PseudoHeader(src, dst net.IP, lenght uint32, proto uint8) []byte {
header := icmp.IPv6PseudoHeader(src, dst)
header := icmp.IPv6PseudoHeader(src.ToNetIP(), dst.ToNetIP())
lenBytes := convert.Uint32Byte(uint32(lenght))
copy(header[32:36], lenBytes)
......@@ -50,7 +50,7 @@ func IPv6PseudoHeader(src, dst net.IP, lenght uint32, proto uint8) []byte {
//
// Specify the position of the checksum using sumAt.
// Use a value lower than 0 to not skip a checksum field.
func IPv6UpperLayerChecksum(src, dst net.IP, proto uint8, pl []byte, sumAt int) uint16 {
func IPv6UpperLayerChecksum(src, dst net.IP, proto uint8, pl []byte) uint16 {
header := IPv6PseudoHeader(src, dst, uint32(len(pl)), proto)
b := append(header, pl...)
......@@ -59,9 +59,6 @@ func IPv6UpperLayerChecksum(src, dst net.IP, proto uint8, pl []byte, sumAt int)
// skipping only the checksum field itself."
var chk uint32
for i := 0; i < len(b); i += 2 {
if sumAt > 0 && i == len(header)+sumAt {
continue
}
chk += uint32(binary.BigEndian.Uint16(b[i : i+2]))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment