4

I am writing code that calculates the checksums of the IP and TCP headers to detect malformed packets. The IP checksum works correctly, but when I reuse the same code for the TCP checksum it produces incorrect values. Where is my error?

void cPacket::CheckIfMalformed()
{
    if (Packet->isIPPacket)
    {
        IP_HEADER ipheader;
        memcpy(&ipheader,(void*)&Packet->IPHeader,sizeof(IP_HEADER));
        ipheader.ip_checksum =0;
        cout << (DWORD*)Checksum((USHORT*)&ipheader,sizeof(IP_HEADER)) << endl;
        if(Checksum((USHORT*)&ipheader,sizeof(IP_HEADER)) != Packet->IPHeader.Checksum)
        {
            Packet->isMalformed = true;
            Packet->PacketError = PACKET_IP_CHECKSUM;
        }   
        else if (Packet->isTCPPacket)
        {
            TCP_HEADER tcpheader;
            memcpy(&tcpheader,(void*)&Packet->TCPHeader,sizeof(TCP_HEADER));
            tcpheader.checksum = 0;

            PSEUDO_HEADER psheader;
            psheader.daddr = Packet->IPHeader.DestinationAddress;
            psheader.saddr = Packet->IPHeader.SourceAddress;
            psheader.protocol = Packet->IPHeader.Protocol;
            psheader.length = ntohs(Packet->IPHeader.TotalLength) - Packet->IPHeader.HeaderLength*4;
            psheader.zero = 0;

            unsigned char *tcppacket;
            tcppacket = (unsigned char*)malloc(Packet->Size + sizeof(PSEUDO_HEADER));
            memset(tcppacket,0, Packet->Size + sizeof(PSEUDO_HEADER));
            memcpy(&tcppacket[0], &psheader, sizeof(PSEUDO_HEADER));
            memcpy(&tcppacket[12], &tcpheader, ntohs(Packet->IPHeader.TotalLength) - Packet->IPHeader.HeaderLength*4 );

            cout << (DWORD*)ntohs(Checksum((USHORT*)tcppacket,Packet->Size + sizeof(PSEUDO_HEADER))) << endl;
        }
    }
};

// Computes the Internet checksum (RFC 1071): the 16-bit ones' complement of
// the ones' complement sum of the data taken as 16-bit words.
//
// buffer - start of the data. It is read as raw 16-bit words in memory order,
//          so the result has the same byte order as a checksum field stored
//          in a packet and can be compared with it directly.
// length - number of BYTES to sum. If odd, the final byte is treated as if it
//          were followed by a zero pad byte, as RFC 1071 requires.
// Returns the ones' complement of the folded 16-bit sum.
USHORT cPacket::Checksum(USHORT *buffer, unsigned int length)
{
    // Unsigned accumulator: wrap-around is well-defined (signed overflow is
    // not), and carries are folded back in below.
    unsigned long sum = 0;
    const USHORT *word = buffer;
    unsigned int remaining = length;

    while (remaining > 1)
    {
        sum += *word++;
        remaining -= 2;
    }

    // Odd length: place the last byte in the first memory byte of a
    // zero-initialized word. The previous version dropped this byte, which
    // breaks checksums of odd-length TCP/UDP segments.
    if (remaining == 1)
    {
        USHORT last = 0;
        *reinterpret_cast<unsigned char *>(&last) =
            *reinterpret_cast<const unsigned char *>(word);
        sum += last;
    }

    // Fold carries until the sum fits in 16 bits. A loop is correct for any
    // number of carries, unlike a fixed two-step fold.
    while (sum >> 16)
        sum = (sum & 0xFFFF) + (sum >> 16);

    return static_cast<USHORT>(~sum);
}
4

0 回答 0