1import io, os, re, serial, struct, time
2from errno import EPERM
3from .console import VT_ENABLED
4
5try:
6    from .pyboard import Pyboard, PyboardError, stdout_write_bytes, filesystem_command
7except ImportError:
8    import sys
9
10    sys.path.append(os.path.dirname(__file__) + "/../..")
11    from pyboard import Pyboard, PyboardError, stdout_write_bytes, filesystem_command
12
# Numeric identifiers of the filesystem requests the device can send to the
# host.  The device-side hook code refers to them by name; the names are
# substituted with these values when the hook code is compressed below.
fs_hook_cmds = dict(
    CMD_STAT=1,
    CMD_ILISTDIR_START=2,
    CMD_ILISTDIR_NEXT=3,
    CMD_OPEN=4,
    CMD_CLOSE=5,
    CMD_READ=6,
    CMD_WRITE=7,
    CMD_SEEK=8,
    CMD_REMOVE=9,
    CMD_RENAME=10,
)
25
# MicroPython source that is executed on the device to expose the host
# directory as a filesystem mounted at /remote.
#
# The text is post-processed before it is sent (see the compression step that
# follows this literal): CMD_* names become numbers, "#" comments are removed,
# every run of four spaces becomes one space, and "rd_" / "wr_" / "buf4" are
# shortened.  Consequently, inside this literal:
#   - use "#" comments only (docstrings would be transferred to the device),
#   - never put "#" or four consecutive spaces inside a string literal.
#
# Paths handed to RemoteFS may be absolute (relative to the mount point) or
# relative to the directory selected with chdir(); _abspath() normalises both.
fs_hook_code = """\
import uos, uio, ustruct, micropython

SEEK_SET = 0

class RemoteCommand:
    def __init__(self):
        import uselect, usys
        self.buf4 = bytearray(4)
        self.fout = usys.stdout.buffer
        self.fin = usys.stdin.buffer
        self.poller = uselect.poll()
        self.poller.register(self.fin, uselect.POLLIN)

    def poll_in(self):
        for _ in self.poller.ipoll(1000):
            return
        self.end()
        raise Exception('timeout waiting for remote')

    def rd(self, n):
        buf = bytearray(n)
        self.rd_into(buf, n)
        return buf

    def rd_into(self, buf, n):
        # implement reading with a timeout in case other side disappears
        if n == 0:
            return
        self.poll_in()
        r = self.fin.readinto(buf, n)
        if r < n:
            mv = memoryview(buf)
            while r < n:
                self.poll_in()
                r += self.fin.readinto(mv[r:], n - r)

    def begin(self, type):
        micropython.kbd_intr(-1)
        buf4 = self.buf4
        buf4[0] = 0x18
        buf4[1] = type
        self.fout.write(buf4, 2)
        # Wait for sync byte 0x18, but don't get stuck forever
        for i in range(30):
            self.poller.poll(1000)
            self.fin.readinto(buf4, 1)
            if buf4[0] == 0x18:
                break

    def end(self):
        micropython.kbd_intr(3)

    def rd_s8(self):
        self.rd_into(self.buf4, 1)
        n = self.buf4[0]
        if n & 0x80:
            n -= 0x100
        return n

    def rd_s32(self):
        buf4 = self.buf4
        self.rd_into(buf4, 4)
        n = buf4[0] | buf4[1] << 8 | buf4[2] << 16 | buf4[3] << 24
        if buf4[3] & 0x80:
            n -= 0x100000000
        return n

    def rd_u32(self):
        buf4 = self.buf4
        self.rd_into(buf4, 4)
        return buf4[0] | buf4[1] << 8 | buf4[2] << 16 | buf4[3] << 24

    def rd_bytes(self, buf):
        # TODO if n is large (eg >256) then we may miss bytes on stdin
        n = self.rd_s32()
        if buf is None:
            ret = buf = bytearray(n)
        else:
            ret = n
        self.rd_into(buf, n)
        return ret

    def rd_str(self):
        n = self.rd_s32()
        if n == 0:
            return ''
        else:
            return str(self.rd(n), 'utf8')

    def wr_s8(self, i):
        self.buf4[0] = i
        self.fout.write(self.buf4, 1)

    def wr_s32(self, i):
        ustruct.pack_into('<i', self.buf4, 0, i)
        self.fout.write(self.buf4)

    def wr_bytes(self, b):
        self.wr_s32(len(b))
        self.fout.write(b)

    # str and bytes act the same in MicroPython
    wr_str = wr_bytes


class RemoteFile(uio.IOBase):
    def __init__(self, cmd, fd, is_text):
        self.cmd = cmd
        self.fd = fd
        self.is_text = is_text

    def __enter__(self):
        return self

    def __exit__(self, a, b, c):
        self.close()

    def ioctl(self, request, arg):
        if request == 4:  # CLOSE
            self.close()
        return 0

    def flush(self):
        pass

    def close(self):
        if self.fd is None:
            return
        c = self.cmd
        c.begin(CMD_CLOSE)
        c.wr_s8(self.fd)
        c.end()
        self.fd = None

    def read(self, n=-1):
        c = self.cmd
        c.begin(CMD_READ)
        c.wr_s8(self.fd)
        c.wr_s32(n)
        data = c.rd_bytes(None)
        c.end()
        if self.is_text:
            data = str(data, 'utf8')
        else:
            data = bytes(data)
        return data

    def readinto(self, buf):
        c = self.cmd
        c.begin(CMD_READ)
        c.wr_s8(self.fd)
        c.wr_s32(len(buf))
        n = c.rd_bytes(buf)
        c.end()
        return n

    def readline(self):
        # Accumulate in the type that read() returns for this file so that
        # binary files work as well as text files.
        l = '' if self.is_text else b''
        nl = '\\n' if self.is_text else b'\\n'
        while 1:
            c = self.read(1)
            l += c
            if c == nl or not c:
                return l

    def readlines(self):
        ls = []
        while 1:
            l = self.readline()
            if not l:
                return ls
            ls.append(l)

    def write(self, buf):
        c = self.cmd
        c.begin(CMD_WRITE)
        c.wr_s8(self.fd)
        c.wr_bytes(buf)
        n = c.rd_s32()
        c.end()
        return n

    def seek(self, n, whence=SEEK_SET):
        c = self.cmd
        c.begin(CMD_SEEK)
        c.wr_s8(self.fd)
        c.wr_s32(n)
        c.wr_s8(whence)
        n = c.rd_s32()
        c.end()
        if n < 0:
            # The host reports failure as a negated error number.
            raise OSError(-n)
        return n


class RemoteFS:
    def __init__(self, cmd):
        self.cmd = cmd
        self.path = '/'

    def _abspath(self, path):
        # Absolute paths are already relative to the mount point.
        return path if path.startswith("/") else self.path + path

    def mount(self, readonly, mkfs):
        pass

    def umount(self):
        pass

    def chdir(self, path):
        path = self._abspath(path)
        if not path.endswith("/"):
            path += "/"
        if path != "/":
            self.stat(path)
        self.path = path

    def getcwd(self):
        return self.path

    def remove(self, path):
        c = self.cmd
        c.begin(CMD_REMOVE)
        c.wr_str(self._abspath(path))
        res = c.rd_s32()
        c.end()
        if res < 0:
            raise OSError(-res)

    def rename(self, old, new):
        c = self.cmd
        c.begin(CMD_RENAME)
        c.wr_str(self._abspath(old))
        c.wr_str(self._abspath(new))
        res = c.rd_s32()
        c.end()
        if res < 0:
            raise OSError(-res)

    def stat(self, path):
        c = self.cmd
        c.begin(CMD_STAT)
        c.wr_str(self._abspath(path))
        res = c.rd_s8()
        if res < 0:
            c.end()
            raise OSError(-res)
        mode = c.rd_u32()
        size = c.rd_u32()
        atime = c.rd_u32()
        mtime = c.rd_u32()
        ctime = c.rd_u32()
        c.end()
        return mode, 0, 0, 0, 0, 0, size, atime, mtime, ctime

    def ilistdir(self, path):
        c = self.cmd
        c.begin(CMD_ILISTDIR_START)
        c.wr_str(self._abspath(path))
        res = c.rd_s8()
        c.end()
        if res < 0:
            raise OSError(-res)
        def next():
            while True:
                c.begin(CMD_ILISTDIR_NEXT)
                name = c.rd_str()
                if name:
                    type = c.rd_u32()
                    c.end()
                    yield (name, type, 0)
                else:
                    c.end()
                    break
        return next()

    def open(self, path, mode):
        c = self.cmd
        c.begin(CMD_OPEN)
        c.wr_str(self._abspath(path))
        c.wr_str(mode)
        fd = c.rd_s8()
        c.end()
        if fd < 0:
            raise OSError(-fd)
        return RemoteFile(c, fd, mode.find('b') == -1)


def __mount():
    uos.mount(RemoteFS(RemoteCommand()), '/remote')
    uos.chdir('/remote')
"""
315
# Shrink the hook code so there is less text to transfer to the device.
# First swap the symbolic command names for their numeric values.
for key, value in fs_hook_cmds.items():
    fs_hook_code = fs_hook_code.replace(key, str(value))
# Strip comments, then squeeze the runs of blank lines into single newlines.
fs_hook_code = re.sub(" *#.*$", "", fs_hook_code, flags=re.MULTILINE)
fs_hook_code = re.sub("\n\n+", "\n", fs_hook_code)
# Use one space per indentation level and shorten frequently used names.
fs_hook_code = (
    fs_hook_code.replace("    ", " ")
    .replace("rd_", "r")
    .replace("wr_", "w")
    .replace("buf4", "b4")
)
325
326
class PyboardCommand:
    """Host-side handler for the filesystem requests issued by the device.

    Arguments are read from ``fin`` and replies are written to ``fout`` as
    little-endian fixed-width integers and length-prefixed byte strings.
    Every path received from the device is appended to ``root`` and must
    resolve to a location inside it.

    The ``do_*`` methods each serve one request type and are looked up through
    ``cmd_table`` by the numeric command identifier.
    """

    def __init__(self, fin, fout, path):
        """Store the I/O endpoints and the host directory ``path`` to serve."""
        self.fin = fin
        self.fout = fout
        self.root = path + "/"
        # [directory being listed, entry names not yet returned to the device]
        self.data_ilistdir = ["", []]
        # Open files indexed by descriptor; each slot holds (file, is_text) or
        # None when the descriptor is free for reuse.
        self.data_files = []

    @staticmethod
    def _clamp_u32(value):
        """Return ``value`` as an int limited to what ``wr_u32`` can encode."""
        return max(0, min(int(value), 0xFFFFFFFF))

    def rd_s8(self):
        """Read a signed 8-bit integer."""
        return struct.unpack("<b", self.fin.read(1))[0]

    def rd_s32(self):
        """Read a signed little-endian 32-bit integer."""
        return struct.unpack("<i", self.fin.read(4))[0]

    def rd_bytes(self):
        """Read a 32-bit length followed by that many bytes."""
        n = self.rd_s32()
        return self.fin.read(n)

    def rd_str(self):
        """Read a length-prefixed UTF-8 string."""
        n = self.rd_s32()
        if n == 0:
            return ""
        else:
            return str(self.fin.read(n), "utf8")

    def wr_s8(self, i):
        """Write a signed 8-bit integer."""
        self.fout.write(struct.pack("<b", i))

    def wr_s32(self, i):
        """Write a signed little-endian 32-bit integer."""
        self.fout.write(struct.pack("<i", i))

    def wr_u32(self, i):
        """Write an unsigned little-endian 32-bit integer."""
        self.fout.write(struct.pack("<I", i))

    def wr_bytes(self, b):
        """Write a 32-bit length followed by the bytes ``b``."""
        self.wr_s32(len(b))
        self.fout.write(b)

    def wr_str(self, s):
        """Write ``s`` as a length-prefixed UTF-8 string."""
        b = bytes(s, "utf8")
        self.wr_s32(len(b))
        self.fout.write(b)

    def log_cmd(self, msg):
        """Print a debug message; CRLF is used as the line terminator."""
        print(f"[{msg}]", end="\r\n")

    def path_check(self, path):
        """Raise ``OSError(EPERM)`` unless ``path`` resolves inside ``root``.

        Both paths are resolved with ``os.path.realpath`` before comparing, so
        symbolic links and ``..`` components cannot be used to escape.
        """
        parent = os.path.realpath(self.root)
        child = os.path.realpath(path)
        try:
            inside = parent == os.path.commonpath([parent, child])
        except ValueError:
            # commonpath() rejects paths that cannot share a prefix (for
            # example different drives on Windows): that is outside as well.
            inside = False
        if not inside:
            raise OSError(EPERM, "")  # File is outside mounted dir

    def do_stat(self):
        """Reply with a status byte and, on success, mode/size/timestamps."""
        path = self.root + self.rd_str()
        # self.log_cmd(f"stat {path}")
        try:
            self.path_check(path)
            stat = os.stat(path)
        except OSError as er:
            self.wr_s8(-abs(er.errno))
        else:
            self.wr_s8(0)
            # Note: st_ino would need to be 64-bit if added here
            self.wr_u32(stat.st_mode)
            # The status byte is already on the wire, so out-of-range values
            # (size >= 4 GiB, negative timestamps) are clamped rather than
            # allowed to raise struct.error halfway through the reply.
            self.wr_u32(self._clamp_u32(stat.st_size))
            self.wr_u32(self._clamp_u32(stat.st_atime))
            self.wr_u32(self._clamp_u32(stat.st_mtime))
            self.wr_u32(self._clamp_u32(stat.st_ctime))

    def do_ilistdir_start(self):
        """Snapshot the entries of the requested directory and reply status."""
        path = self.root + self.rd_str()
        try:
            self.path_check(path)
            # List before replying so that a failure here is reported to the
            # device instead of aborting after a success status was sent.
            entries = os.listdir(path)
        except OSError as er:
            self.wr_s8(-abs(er.errno))
        else:
            self.data_ilistdir[0] = path
            self.data_ilistdir[1] = entries
            self.wr_s8(0)

    def do_ilistdir_next(self):
        """Reply with the next entry name and type, or "" when exhausted."""
        if self.data_ilistdir[1]:
            entry = self.data_ilistdir[1].pop(0)
            try:
                stat = os.lstat(self.data_ilistdir[0] + "/" + entry)
                # Keep only the two top file-type bits of st_mode.
                mode = stat.st_mode & 0xC000
            except OSError:
                # Entry cannot be inspected: report it with an unknown type.
                mode = 0
            self.wr_str(entry)
            self.wr_u32(mode)
        else:
            self.wr_str("")

    def do_open(self):
        """Open a file and reply with its descriptor or a negated errno."""
        path = self.root + self.rd_str()
        mode = self.rd_str()
        # self.log_cmd(f"open {path} {mode}")
        try:
            self.path_check(path)
            f = open(path, mode)
        except OSError as er:
            self.wr_s8(-abs(er.errno))
        else:
            is_text = mode.find("b") == -1
            try:
                # Reuse the lowest descriptor released by do_close().
                fd = self.data_files.index(None)
                self.data_files[fd] = (f, is_text)
            except ValueError:
                fd = len(self.data_files)
                self.data_files.append((f, is_text))
            self.wr_s8(fd)

    def do_close(self):
        """Close the file for the given descriptor and free its slot."""
        fd = self.rd_s8()
        # self.log_cmd(f"close {fd}")
        self.data_files[fd][0].close()
        self.data_files[fd] = None

    def do_read(self):
        """Read up to ``n`` units from a file and reply with the bytes."""
        fd = self.rd_s8()
        n = self.rd_s32()
        buf = self.data_files[fd][0].read(n)
        if self.data_files[fd][1]:
            # Text-mode files yield str; the wire format is UTF-8 bytes.
            buf = bytes(buf, "utf8")
        self.wr_bytes(buf)
        # self.log_cmd(f"read {fd} {n} -> {len(buf)}")

    def do_seek(self):
        """Seek within a file; reply with the new position or a negated errno."""
        fd = self.rd_s8()
        n = self.rd_s32()
        whence = self.rd_s8()
        # self.log_cmd(f"seek {fd} {n}")
        try:
            n = self.data_files[fd][0].seek(n, whence)
        except (OSError, ValueError) as er:
            # io.UnsupportedOperation derives from both exception types and
            # has no errno; it and plain ValueError are reported as -EPERM.
            n = -abs(getattr(er, "errno", None) or EPERM)
        self.wr_s32(n)

    def do_write(self):
        """Write the received bytes to a file and reply with the count."""
        fd = self.rd_s8()
        buf = self.rd_bytes()
        if self.data_files[fd][1]:
            # Text-mode files expect str.
            buf = str(buf, "utf8")
        n = self.data_files[fd][0].write(buf)
        self.wr_s32(n)
        # self.log_cmd(f"write {fd} {len(buf)} -> {n}")

    def do_remove(self):
        """Remove a file; reply with 0 or a negated errno."""
        path = self.root + self.rd_str()
        # self.log_cmd(f"remove {path}")
        try:
            self.path_check(path)
            os.remove(path)
            ret = 0
        except OSError as er:
            ret = -abs(er.errno)
        self.wr_s32(ret)

    def do_rename(self):
        """Rename a file; reply with 0 or a negated errno."""
        old = self.root + self.rd_str()
        new = self.root + self.rd_str()
        # self.log_cmd(f"rename {old} {new}")
        try:
            self.path_check(old)
            self.path_check(new)
            os.rename(old, new)
            ret = 0
        except OSError as er:
            ret = -abs(er.errno)
        self.wr_s32(ret)

    # Maps the numeric command identifier to the (unbound) handler method.
    cmd_table = {
        fs_hook_cmds["CMD_STAT"]: do_stat,
        fs_hook_cmds["CMD_ILISTDIR_START"]: do_ilistdir_start,
        fs_hook_cmds["CMD_ILISTDIR_NEXT"]: do_ilistdir_next,
        fs_hook_cmds["CMD_OPEN"]: do_open,
        fs_hook_cmds["CMD_CLOSE"]: do_close,
        fs_hook_cmds["CMD_READ"]: do_read,
        fs_hook_cmds["CMD_WRITE"]: do_write,
        fs_hook_cmds["CMD_SEEK"]: do_seek,
        fs_hook_cmds["CMD_REMOVE"]: do_remove,
        fs_hook_cmds["CMD_RENAME"]: do_rename,
    }
510
511
class SerialIntercept:
    """Serial-port wrapper that services filesystem requests from the device.

    Bytes coming from the device are inspected one at a time.  A 0x18 byte
    introduces a filesystem request, which is acknowledged and dispatched
    through ``PyboardCommand.cmd_table``; everything else is buffered and
    handed to the caller through ``read`` / ``inWaiting``.
    """

    def __init__(self, serial, cmd):
        """Wrap ``serial``; ``cmd`` is passed to the request handlers."""
        self.orig_serial = serial
        self.cmd = cmd
        self.buf = b""
        # Bound every blocking read on the underlying port.
        self.orig_serial.timeout = 5.0

    def _check_input(self, blocking):
        """Consume at most one item (byte, request or escape sequence).

        When ``blocking`` is false nothing is read unless the underlying port
        reports pending input.
        """
        if not blocking and self.orig_serial.inWaiting() <= 0:
            return
        c = self.orig_serial.read(1)
        if c == b"\x18":
            # a special command
            c = self.orig_serial.read(1)[0]
            self.orig_serial.write(b"\x18")  # Acknowledge command
            PyboardCommand.cmd_table[c](self.cmd)
        elif not VT_ENABLED and c == b"\x1b":
            # ESC code, ignore these on windows
            esctype = self.orig_serial.read(1)
            if esctype == b"[":  # CSI
                # Discard bytes up to and including the "final byte", which is
                # anything in 0x40..0x7E inclusive ('@' and '~' count too).
                # An empty read means the port timed out: stop skipping rather
                # than fail on a truncated sequence.
                while True:
                    final = self.orig_serial.read(1)
                    if not final or 0x40 <= final[0] <= 0x7E:
                        break
        else:
            self.buf += c

    @property
    def fd(self):
        """File descriptor of the underlying serial port."""
        return self.orig_serial.fd

    def close(self):
        """Close the underlying serial port."""
        self.orig_serial.close()

    def inWaiting(self):
        """Process pending input (if any) and return the buffered byte count."""
        self._check_input(False)
        return len(self.buf)

    def read(self, n):
        """Return exactly ``n`` buffered bytes, processing input until available."""
        while len(self.buf) < n:
            self._check_input(True)
        out = self.buf[:n]
        self.buf = self.buf[n:]
        return out

    def write(self, buf):
        """Write ``buf`` straight to the underlying serial port."""
        self.orig_serial.write(buf)
557
558
class PyboardExtended(Pyboard):
    """``Pyboard`` that can expose a host directory to the device as /remote."""

    def __init__(self, dev, *args, **kwargs):
        """Open the device ``dev``; extra arguments go to ``Pyboard``."""
        super().__init__(dev, *args, **kwargs)
        self.device_name = dev
        self.mounted = False

    def mount_local(self, path):
        """Mount the host directory ``path`` on the device.

        The hook code is only sent when the device does not already define
        ``RemoteFS``.  Once the device has mounted the filesystem, the serial
        port is wrapped so that its filesystem requests are answered.
        """
        if self.eval('"RemoteFS" in globals()') == b"False":
            self.exec_(fs_hook_code)
        self.exec_("__mount()")
        self.cmd = PyboardCommand(self.serial, self.serial, path)
        self.serial = SerialIntercept(self.serial, self.cmd)
        # Only flag the mount once the intercepting wrapper is installed, so a
        # failure above leaves this object in its unmounted state.
        self.mounted = True

    def soft_reset_with_mount(self, out_callback):
        """Soft-reset the device and re-establish the mount afterwards.

        Output produced by the device during the reset is forwarded to
        ``out_callback`` as bytes.  When nothing is mounted only the reset
        byte is sent.
        """
        self.serial.write(b"\x04")
        if not self.mounted:
            return

        # Wait for a response to the soft-reset command.
        for _ in range(10):
            if self.serial.inWaiting():
                break
            time.sleep(0.05)
        else:
            # Device didn't respond so it wasn't in a state to do a soft reset.
            return

        out_callback(self.serial.read(1))
        # Talk to the raw port while the device-side hook does not exist.
        self.serial = self.serial.orig_serial
        n = self.serial.inWaiting()
        while n > 0:
            buf = self.serial.read(n)
            out_callback(buf)
            time.sleep(0.1)
            n = self.serial.inWaiting()
        # Reinstall the hook and mount again, then re-wrap the serial port.
        self.serial.write(b"\x01")
        self.exec_(fs_hook_code)
        self.exec_("__mount()")
        self.exit_raw_repl()
        self.read_until(4, b">>> ")
        self.serial = SerialIntercept(self.serial, self.cmd)

    def umount_local(self):
        """Unmount /remote on the device and stop intercepting the serial port."""
        if self.mounted:
            self.exec_('uos.umount("/remote")')
            self.mounted = False
            # Drop the wrapper so that a later mount_local() does not stack a
            # second SerialIntercept on top of this one.  getattr() covers the
            # case where the raw port is already in place.
            self.serial = getattr(self.serial, "orig_serial", self.serial)
607