1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4MIT License
5
6Copyright (c) 2018 Alex Woo
7
8Permission is hereby granted, free of charge, to any person obtaining a copy
9of this software and associated documentation files (the "Software"), to deal
10in the Software without restriction, including without limitation the rights
11to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12copies of the Software, and to permit persons to whom the Software is
13furnished to do so, subject to the following conditions:
14
15The above copyright notice and this permission notice shall be included in all
16copies or substantial portions of the Software.
17
18THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24SOFTWARE.
25"""
26from __future__ import print_function
27import os, sys, math, time, string, struct
28
# YModem control bytes.  Each one is a single-byte ``bytes`` object so it can
# be compared directly with what ``getc(1)`` returns and handed to ``putc``.
SOH = b'\x01'  # start byte of a 128-byte block
STX = b'\x02'  # start byte of a 1024-byte block
EOT = b'\x04'  # end of transmission
ACK = b'\x06'  # acknowledge
NAK = b'\x15'  # negative acknowledge
CAN = b'\x18'  # cancel
CRC = b'\x43'  # ASCII 'C': sent by the receiver to request CRC-16 blocks
37
class SendTask(object):
    """Bookkeeping for one outgoing YModem transfer.

    Stores the announced name and size, the number of data packets that
    size needs, and counters that the sender bumps as packets go out and
    are acknowledged.
    """

    def __init__(self):
        self._task_name = ""
        self._task_size = 0
        self._task_packets = 0              # data packets needed for _task_size
        self._last_valid_packets_size = 0   # payload bytes carried by the final data packet
        self._sent_packets = 0              # every packet written, retries included
        self._missing_sent_packets = 0      # sends that were not answered with ACK
        self._valid_sent_packets = 0        # data packets acknowledged with ACK
        self._valid_sent_bytes = 0          # payload bytes of acknowledged packets

    def inc_sent_packets(self):
        self._sent_packets += 1

    def inc_missing_sent_packets(self):
        self._missing_sent_packets += 1

    def inc_valid_sent_packets(self):
        self._valid_sent_packets += 1

    def add_valid_sent_bytes(self, this_valid_sent_bytes):
        self._valid_sent_bytes += this_valid_sent_bytes

    def get_valid_sent_packets(self):
        return self._valid_sent_packets

    def get_valid_sent_bytes(self):
        return self._valid_sent_bytes

    def get_task_name(self):
        return self._task_name

    def get_task_size(self):
        return self._task_size

    def get_task_packets(self):
        return self._task_packets

    def get_last_valid_packet_size(self):
        return self._last_valid_packets_size

    def set_task_name(self, data_name):
        self._task_name = data_name

    def set_task_size(self, data_size, packet_size=128):
        """Record the transfer size and derive the packet layout.

        :param data_size: total payload size in bytes (non-negative int).
        :param packet_size: data bytes per packet (128 by default).
        :raises ValueError: if ``packet_size`` is not positive.

        ``_task_packets`` becomes ``ceil(data_size / packet_size)`` and
        ``_last_valid_packets_size`` the number of real payload bytes in the
        final packet (``packet_size`` when the size divides evenly, 0 for an
        empty transfer).
        """
        if packet_size <= 0:
            raise ValueError("packet_size must be positive")
        self._task_size = data_size
        # Integer ceiling division: exact on Python 2 (where "/" truncates
        # ints) as well as Python 3, and always yields an int.
        self._task_packets = (data_size + packet_size - 1) // packet_size
        if self._task_packets == 0:
            self._last_valid_packets_size = 0
        else:
            # An exact multiple fills its last packet completely; the old
            # "data_size % packet_size" reported 0 bytes in that case.
            self._last_valid_packets_size = data_size - (self._task_packets - 1) * packet_size
74
class ReceiveTask(object):
    """Bookkeeping for one incoming YModem transfer.

    Stores the name and size announced in packet 0, the packet layout
    derived from that size, and counters bumped by the receiver.
    """

    def __init__(self):
        self._task_name = ""
        self._task_size = 0
        self._task_packets = 0              # data packets expected for _task_size
        self._last_valid_packets_size = 0   # payload bytes carried by the final data packet
        self._received_packets = 0
        self._missing_received_packets = 0
        self._valid_received_packets = 0
        self._valid_received_bytes = 0

    def inc_received_packets(self):
        self._received_packets += 1

    def inc_missing_received_packets(self):
        self._missing_received_packets += 1

    def inc_valid_received_packets(self):
        self._valid_received_packets += 1

    def add_valid_received_bytes(self, this_valid_received_bytes):
        self._valid_received_bytes += this_valid_received_bytes

    def get_task_packets(self):
        return self._task_packets

    def get_last_valid_packet_size(self):
        return self._last_valid_packets_size

    def get_valid_received_packets(self):
        return self._valid_received_packets

    def get_valid_received_bytes(self):
        return self._valid_received_bytes

    def set_task_name(self, data_name):
        self._task_name = data_name

    def set_task_size(self, data_size, packet_size=128):
        """Record the announced size and derive the expected packet layout.

        :param data_size: total payload size in bytes (non-negative int).
        :param packet_size: data bytes per packet (128 by default).
        :raises ValueError: if ``packet_size`` is not positive.

        ``get_task_packets()`` becomes ``ceil(data_size / packet_size)`` and
        ``get_last_valid_packet_size()`` the number of real payload bytes in
        the final packet (``packet_size`` when the size divides evenly, 0 for
        an empty transfer).
        """
        if packet_size <= 0:
            raise ValueError("packet_size must be positive")
        self._task_size = data_size
        # Integer ceiling division: exact on Python 2 and 3, always an int.
        self._task_packets = (data_size + packet_size - 1) // packet_size
        if self._task_packets == 0:
            self._last_valid_packets_size = 0
        else:
            # A size that is an exact multiple fills its last packet; the old
            # "% packet_size" returned 0 there, which made callers truncating
            # the final packet to this length discard it entirely.
            self._last_valid_packets_size = data_size - (self._task_packets - 1) * packet_size

    def get_task_name(self):
        return self._task_name

    def get_task_size(self):
        return self._task_size
123
class YModem(object):
    """YModem sender/receiver (CRC-16, 128- and 1024-byte blocks).

    The serial link is abstracted by two callables:

    * ``getc(n)`` -- read up to ``n`` bytes.  A falsy result is treated as
      "nothing received" (presumably ``None`` or ``b''`` on a read timeout;
      this depends on the caller's transport).
    * ``putc(data)`` -- write ``data`` to the link.

    ``header_pad`` separates and pads the fields of packet 0 and fills the
    closing null packet; ``data_pad`` fills the last, partial data packet.
    Both must be a single byte.  Per-transfer statistics live in ``self.st``
    (sending) and ``self.rt`` (receiving); each is reset when a transfer
    starts.
    """

    def __init__(self, getc, putc, header_pad=b'\x00', data_pad=b'\x1a'):
        self.getc = getc
        self.putc = putc
        self.st = SendTask()
        self.rt = ReceiveTask()
        self.header_pad = header_pad
        self.data_pad = data_pad

    def abort(self, count=2):
        """Cancel the transfer by writing ``count`` CAN bytes."""
        for _ in range(count):
            self.putc(CAN)

    def send_file(self, file_path, retry=20, callback=None):
        """Send the file at ``file_path``; see :meth:`send`.

        :returns: whatever :meth:`send` returns, or ``None`` if the file could
            not be opened or read (the error is printed).
        """
        file_name = os.path.basename(file_path)
        try:
            # "with" guarantees the close; the old finally-block hit an
            # unbound name when open() itself failed.
            with open(file_path, 'rb') as file_stream:
                file_size = os.path.getsize(file_path)
                file_sent = self.send(file_stream, file_name, file_size, retry, callback)
        except IOError as e:
            print(str(e))
            return None

        print("File: " + file_name)
        print("Size: " + str(file_sent) + "Bytes")
        return file_sent

    def wait_for_next(self, ch, timeout=60):
        """Read from the link until the single byte ``ch`` arrives.

        Unexpected bytes are printed and skipped.

        :returns: 0 when ``ch`` was received, -1 on the third CAN byte,
            -2 once ``timeout`` seconds have elapsed.
        """
        cancel_count = 0
        tic     = time.time()
        while True:
            c = self.getc(1)
            if c:
                if c == ch:
                    print("<<< " + hex(ord(ch)))
                    break
                elif c == CAN:
                    if cancel_count == 2:
                        return -1
                    else:
                        cancel_count += 1
                else:
                    print("Expected " + hex(ord(ch)) + ", but got " + hex(ord(c)))
            if (time.time() - tic) >= timeout:
                return -2
        return 0

    def _wait_or_abort(self, ch, timeout=60):
        """:meth:`wait_for_next` that also cancels the transfer on timeout.

        :returns: 0 on success, otherwise the negative wait_for_next() code.
        """
        status = self.wait_for_next(ch, timeout)
        if status == -2:
            print("Timeout waiting for " + hex(ord(ch)) + ", aborting")
            self.abort()
        return status

    def _make_header_payload(self, data_name, data_size):
        """Return the 128-byte payload of packet 0.

        Layout: ``name``, ``header_pad``, decimal ``size``, ``header_pad``,
        then ``header_pad`` filling up to 128 bytes.  The name is UTF-8
        encoded (``bytes`` are used as-is) and cut to at most 100 bytes
        without leaving a partial trailing character.  Working in bytes keeps
        non-ASCII names from overflowing the 128-byte block.

        :raises Exception: if the decimal size is longer than 20 characters.
        """
        if isinstance(data_name, bytes):
            name_bytes = data_name
        else:
            name_bytes = data_name.encode('utf-8')
        if len(name_bytes) > 100:
            name_bytes = name_bytes[:100].decode('utf-8', 'ignore').encode('utf-8')

        size_bytes = str(data_size).encode('ascii')
        if len(size_bytes) > 20:
            raise Exception("Data volume is too large!")

        payload = name_bytes + self.header_pad + size_bytes + self.header_pad
        return payload.ljust(128, self.header_pad)

    def send(self, data_stream, data_name, data_size, retry=20, callback=None):
        """Send ``data_size`` bytes read from ``data_stream`` as ``data_name``.

        Sender-side flow:
        [<<< C] [packet 0 >>>] [<<< ACK] [<<< C]
        ([data packet >>>] [<<< ACK])*
        [EOT >>>] [<<< NAK] [EOT >>>] [<<< ACK] [<<< C]
        [null packet 0 >>>] [<<< ACK]

        :param data_stream: binary file-like object, read 128 bytes at a time.
        :param data_name: file name for packet 0 (``str`` or ``bytes``).
        :param data_size: size announced in packet 0 (anything ``int()`` takes).
        :param retry: consecutive non-ACK replies tolerated per data packet.
        :param callback: accepted for API compatibility; currently unused.
        :returns: payload bytes acknowledged by the receiver, or -1 if the
            receiver cancelled, or -2 on a timeout or when ``retry`` is
            exceeded.
        :raises Exception: if the decimal size is longer than 20 characters.
        """
        packet_size = 128
        # Fresh statistics so the returned byte count is for this file only.
        self.st = SendTask()

        # Build packet 0 before touching the link so bad input fails fast.
        header_payload = self._make_header_payload(data_name, data_size)
        self.st.set_task_name(data_name[:100])
        self.st.set_task_size(int(data_size))

        # [<<< CRC]
        status = self._wait_or_abort(CRC)
        if status:
            return status

        # [first packet >>>]
        data_for_send = (self._make_edge_packet_header() + header_payload
                         + self._make_send_checksum(header_payload))
        self.putc(data_for_send)
        self.st.inc_sent_packets()
        print("Packet 0 >>>")

        # [<<< ACK]
        # [<<< CRC]
        for expected in (ACK, CRC):
            status = self._wait_or_abort(expected)
            if status:
                return status

        # [data packet >>>]
        # [<<< ACK]
        error_count = 0
        sequence = 1
        sequence_int = 1
        print("file size: " + str(self.st._task_packets/8) + " KB")
        while True:
            data = data_stream.read(packet_size)
            if not data:
                print('\nEOF')
                break

            extracted_data_bytes = len(data)
            header = self._make_data_packet_header(packet_size, sequence)
            data = data.ljust(packet_size, self.data_pad)
            data_for_send = header + data + self._make_send_checksum(data)

            while True:
                self.putc(data_for_send)
                self.st.inc_sent_packets()
                print("\rPacket " + str(sequence_int) + " / " + str(self.st._task_packets) + " >>>  ", end='')

                c = self.getc(1)
                if c == ACK:
                    print("<<< ACK", end='')
                    self.st.inc_valid_sent_packets()
                    self.st.add_valid_sent_bytes(extracted_data_bytes)
                    error_count = 0
                    break

                # Anything but ACK (NAK, noise, or a read timeout) -> resend.
                error_count += 1
                self.st.inc_missing_sent_packets()
                print("RETRY " + str(error_count))
                if error_count > retry:
                    self.abort()
                    print('send error: NAK received %d , aborting' % retry)
                    return -2

            sequence = (sequence + 1) % 0x100
            sequence_int += 1

        # [EOT >>>]
        # [<<< NAK]
        # The first EOT is repeated with a short timeout as a workaround for a
        # receiver ("2nd boot" in the original note) that may miss it.
        for _ in range(20):
            self.putc(EOT)
            print(">>> EOT")
            status = self.wait_for_next(NAK, 0.5)
            if status == -1:
                return -1
            if status == 0:
                break
            print("to receive NAK timeout. retry...")

        # [EOT >>>]
        # [<<< ACK]
        # [<<< CRC]
        self.putc(EOT)
        print(">>> EOT")
        for expected in (ACK, CRC):
            status = self._wait_or_abort(expected)
            if status:
                return status

        # [Final packet >>>]  (empty packet 0 ends the batch)
        null_payload = b''.ljust(128, self.header_pad)
        data_for_send = (self._make_edge_packet_header() + null_payload
                         + self._make_send_checksum(null_payload))
        self.putc(data_for_send)
        self.st.inc_sent_packets()
        print("Packet End >>>")

        status = self._wait_or_abort(ACK)
        if status:
            return status

        return self.st.get_valid_sent_bytes()

    def wait_for_header(self):
        """Read until a block start byte arrives.

        Other bytes are printed and skipped; there is no timeout.

        :returns: SOH or STX, or -1 on the third CAN byte.
        """
        cancel_count = 0
        while True:
            c = self.getc(1)
            if c:
                if c == SOH or c == STX:
                    return c
                elif c == CAN:
                    if cancel_count == 2:
                        return -1
                    else:
                        cancel_count += 1
                else:
                    print("Expected 0x01(SOH)/0x02(STX)/0x18(CAN), but got " + hex(ord(c)))

    def wait_for_eot(self):
        """Run the receiver side of the EOT handshake.

        Answers the first EOT with NAK and the second with ACK followed by
        'C'.  Other bytes are printed and skipped; there is no timeout.
        """
        eot_count = 0
        while True:
            c = self.getc(1)
            if c:
                if c == EOT:
                    eot_count += 1
                    if eot_count == 1:
                        print("EOT >>>")
                        self.putc(NAK)
                        print("<<< NAK")
                    elif eot_count == 2:
                        print("EOT >>>")
                        self.putc(ACK)
                        print("<<< ACK")
                        self.putc(CRC)
                        print("<<< CRC")
                        break
                else:
                    print("Expected 0x04(EOT), but got " + hex(ord(c)))

    def _read_sequence_bytes(self):
        """Read a block number and its one's complement.

        :returns: ``(seq, 0xFF - complement)``.  Either item is ``None`` when
            ``getc`` returned nothing, so the caller's equality check fails
            instead of raising.
        """
        raw = self.getc(1)
        if not raw:
            return None, None
        seq = ord(raw)
        raw = self.getc(1)
        if not raw:
            return seq, None
        return seq, 0xFF - ord(raw)

    def _parse_header_payload(self, payload):
        """Split a packet-0 payload into ``(file_name, file_size)``.

        The size field may carry extra space-separated fields (e.g. a
        modification time); only its first token is used.  An empty name
        (a null packet 0) yields ``('', 0)``.

        :raises ValueError: if a named header carries no parsable size.
        """
        fields = payload.rstrip(self.header_pad).split(self.header_pad, 1)
        file_name = fields[0].decode('utf-8')
        if not file_name:
            return file_name, 0
        size_field = fields[1].split(self.header_pad)[0] if len(fields) > 1 else b''
        size_tokens = size_field.split()
        if not size_tokens:
            raise ValueError("YModem header for %r carries no file size" % file_name)
        return file_name, int(size_tokens[0].decode('ascii'))

    def recv_file(self, root_path, callback=None):
        """Receive one file and write it into the directory ``root_path``.

        Only the base name of the sender-supplied file name is used, so the
        remote side cannot write outside ``root_path``.

        :param callback: accepted for API compatibility; currently unused.
        :returns: payload bytes written, 0 if the sender's packet 0 carried
            no file name, or -1 if the sender cancelled.
        :raises ValueError: if packet 0 carries no parsable file size.
        """
        # Fresh statistics so the returned byte count is for this file only.
        self.rt = ReceiveTask()

        # [<<< CRC] repeatedly until the sender starts packet 0.
        while True:
            self.putc(CRC)
            print("<<< CRC")
            c = self.getc(1)
            if c:
                if c == SOH:
                    packet_size = 128
                    break
                elif c == STX:
                    packet_size = 1024
                    break
                else:
                    print("Expected 0x01(SOH)/0x02(STX)/0x18(CAN), but got " + hex(ord(c)))

        start_byte_consumed = True   # packet 0's SOH/STX was read above
        header_received = False
        waiting_for_eot = False
        waiting_for_end_packet = False
        sequence = 0
        sequence_int = 0
        file_stream = None
        try:
            while True:
                if waiting_for_eot:
                    self.wait_for_eot()
                    waiting_for_eot = False
                    waiting_for_end_packet = True
                    sequence = 0
                    sequence_int = 0
                    continue

                if start_byte_consumed:
                    start_byte_consumed = False
                else:
                    c = self.wait_for_header()
                    if c == SOH:
                        packet_size = 128
                    elif c == STX:
                        packet_size = 1024
                    else:
                        return c

                seq, seq_oc = self._read_sequence_bytes()
                data = self.getc(packet_size + 2)
                if not (seq == seq_oc == sequence):
                    # Out-of-sequence or garbled block number: drop it without
                    # replying, as before; the sender resends after its reply
                    # read comes back without an ACK.
                    continue

                self.rt.inc_received_packets()
                payload = None
                valid = False
                if data and len(data) == packet_size + 2:
                    valid, payload = self._verify_recv_checksum(data)
                if not valid:
                    # Short read or CRC mismatch: request a retransmission.
                    self.rt.inc_missing_received_packets()
                    self.putc(NAK)
                    print("<<< NAK")
                    continue

                if not header_received and not waiting_for_end_packet:
                    # first packet
                    # [<<< ACK]
                    # [<<< CRC]
                    print("Packet 0 >>>")
                    file_name, data_size = self._parse_header_payload(payload)
                    if not file_name:
                        # Null packet 0: the sender has nothing to send.
                        self.putc(ACK)
                        print("<<< ACK")
                        return 0
                    # Never let the remote side pick a path outside root_path.
                    file_name = os.path.basename(file_name)
                    file_stream = open(os.path.join(root_path, file_name), 'wb+')
                    self.putc(ACK)
                    print("<<< ACK")
                    self.putc(CRC)
                    print("<<< CRC")
                    print("TASK: " + file_name + " " + str(data_size) + "Bytes")
                    self.rt.set_task_name(file_name)
                    self.rt.set_task_size(data_size)
                    header_received = True
                    sequence = (sequence + 1) % 0x100
                    sequence_int += 1
                    if data_size == 0:
                        # Empty file: the sender goes straight to EOT.
                        waiting_for_eot = True

                elif not waiting_for_end_packet:
                    # data packet
                    # [<<< ACK]
                    self.rt.inc_valid_received_packets()
                    print("\rPacket " + str(sequence_int) + " >>>  ", end='')
                    # Track remaining bytes rather than packet counts so both
                    # 128- and 1024-byte blocks are handled and a full final
                    # block is kept intact.
                    remaining = self.rt.get_task_size() - self.rt.get_valid_received_bytes()
                    valid_data = payload[:remaining]
                    if remaining <= len(payload):
                        # last data packet: the rest is padding, EOT follows
                        waiting_for_eot = True
                    self.rt.add_valid_received_bytes(len(valid_data))
                    file_stream.write(valid_data)
                    self.putc(ACK)
                    print("<<< ACK", end='')

                    sequence = (sequence + 1) % 0x100
                    sequence_int += 1

                else:
                    # final packet
                    # [<<< ACK]
                    print("Packet End >>>")
                    self.putc(ACK)
                    print("<<< ACK")
                    break
        finally:
            # Close on every exit path, including early returns and errors.
            if file_stream is not None:
                file_stream.close()

        print("File: " + self.rt.get_task_name())
        print("Size: " + str(self.rt.get_task_size()) + "Bytes")
        return self.rt.get_valid_received_bytes()

    # Header byte
    def _make_edge_packet_header(self):
        """Header of packet 0 / the null packet: SOH, block 0, complement 0xFF."""
        _bytes = [ord(SOH), 0, 0xff]
        return bytearray(_bytes)

    def _make_data_packet_header(self, packet_size, sequence):
        """Header of a data packet: SOH (128) or STX (1024), seq, 0xFF - seq."""
        assert packet_size in (128, 1024), packet_size
        _bytes = []
        if packet_size == 128:
            _bytes.append(ord(SOH))
        elif packet_size == 1024:
            _bytes.append(ord(STX))
        _bytes.extend([sequence, 0xff - sequence])
        return bytearray(_bytes)

    # Make check code
    def _make_send_checksum(self, data):
        """Return the CRC-16 of ``data`` as two big-endian bytes."""
        _bytes = []
        crc = self.calc_crc(data)
        _bytes.extend([crc >> 8, crc & 0xff])
        return bytearray(_bytes)

    def _verify_recv_checksum(self, data):
        """Check the trailing big-endian CRC-16 of ``data``.

        :returns: ``(valid, payload)`` where ``payload`` is ``data`` without
            its last two bytes.
        """
        _checksum = bytearray(data[-2:])
        their_sum = (_checksum[0] << 8) + _checksum[1]
        data = data[:-2]

        our_sum = self.calc_crc(data)
        valid = bool(their_sum == our_sum)
        return valid, data

    # For CRC algorithm
    crctable = [
        0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
        0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
        0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
        0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
        0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
        0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
        0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
        0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
        0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
        0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
        0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
        0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
        0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
        0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
        0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
        0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
        0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
        0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
        0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
        0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
        0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
        0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
        0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
        0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
        0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
        0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
        0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
        0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
        0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
        0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
        0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
        0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0,
    ]

    # CRC algorithm: CCITT-0
    def calc_crc(self, data, crc=0):
        """Table-driven CRC-16 (polynomial 0x1021, initial value ``crc``).

        ``str`` input is UTF-8 encoded on Python 3; byte-like input is used
        as-is.
        """
        if sys.version_info.major == 2:
            ba = struct.unpack("@%dB" % len(data), data)
        else:
            if isinstance(data, str):
                ba = bytearray(data, 'utf-8')
            else:
                ba = bytearray(data)
        for char in ba:
            crctbl_idx = ((crc >> 8) ^ char) & 0xff
            crc = ((crc << 8) ^ self.crctable[crctbl_idx]) & 0xffff
        return crc & 0xffff
529