1#!/usr/bin/env python3
2#
3# SPDX-License-Identifier: BSD-3-Clause
4# SPDX-FileCopyrightText: Copyright TF-RMM Contributors.
5# SPDX-FileCopyrightText: Copyright Arm Limited and Contributors.
6#
7
8from argparse import ArgumentParser
9import codecs
10import os
11import re
12import sys
13import logging
14from os import access, R_OK
15from os.path import isfile
16
# Matches a preprocessor include directive at the start of a line (optional
# leading whitespace, optional whitespace after '#') and captures its target
# together with the surrounding delimiters, e.g. '"foo.h"' or '<foo.h>'.
# The target ends at the first matching closing delimiter so that trailing
# comments on the same line are not swallowed, and mismatched delimiters such
# as '<foo.h"' are rejected. No whitespace is required after 'include' since
# '#include"foo.h"' is valid C.
INCLUDE_RE = re.compile(r'^\s*#\s*include\s*(?P<path>"[^"\n]+"|<[^>\n]+>)')
18
def print_error_and_exit(total_errors):
    """Terminate the process with a status reflecting *total_errors*.

    When *total_errors* is truthy, a summary line is printed and the process
    exits with status 1; otherwise it exits silently with status 0.
    """
    if not total_errors:
        sys.exit(0)

    print("total: " + str(total_errors) + " errors")
    sys.exit(1)
26
def include_paths(lines):
    """List the include targets found in *lines*, in order of appearance.

    :param lines: iterable of source-text lines (e.g. an open text file).
    :return: list of the ``path`` groups matched by ``INCLUDE_RE``, i.e. each
             include target still wrapped in its delimiters, such as
             ``'<stdio.h>'`` or ``'"foo.h"'``.

    Only lines whose first non-whitespace character is ``#`` can match, so
    diff-prefixed lines (e.g. starting with ``+``) are not recognised.
    """
    return [m.group("path") for m in map(INCLUDE_RE.match, lines) if m]
32
def file_readable(file):
    """Return 1 if *file* is an existing regular file readable by us, else 0.

    On failure a warning naming the problem (missing / not a regular file,
    or lacking read permission) is printed before returning 0.
    """
    if isfile(file):
        if access(file, R_OK):
            return 1
        print(file + ": WARNING: File not readable")
    else:
        print(file + ": WARNING: File not found")
    return 0
44
def file_include_list(path):
    """Return the include targets found in the file at *path*.

    The file is decoded as UTF-8 and scanned with ``include_paths``. If the
    file cannot be opened or is not valid UTF-8, the error is logged together
    with its traceback and an empty list is returned. The caller therefore
    sees the file as having no includes rather than crashing.
    """
    try:
        # Builtin open() with an explicit encoding replaces the deprecated
        # codecs.open(); it splits lines only on \n / \r / \r\n, matching how
        # a C preprocessor sees line boundaries.
        with open(path, encoding="utf-8") as f:
            return include_paths(f)
    except (OSError, UnicodeDecodeError):
        logging.exception("%s: ERROR while parsing.", path)
        return []
53
def check_includes(file):
    """Verify that the #include directives in *file* appear in sorted order.

    A progress line is printed for every file. Returns 1 (and prints the
    expected order) when the includes are out of order, otherwise 0. This
    includes unreadable files and files with fewer than two includes, which
    are not treated as errors.

    The surrounding delimiters are stripped before comparing, so quoted and
    angle-bracket includes are ordered together purely by their path text.
    """
    print("Checking file: " + file)
    if not file_readable(file):
        return 0

    # Drop the leading '<'/'"' and trailing '>'/'"' of every include target.
    headers = [target[1:-1] for target in file_include_list(file)]

    # With fewer than two includes there is nothing that could be misordered.
    if len(headers) < 2:
        return 0

    expected = sorted(headers)
    if headers == expected:
        return 0

    print(file + ": ERROR: includes not in order. Include order should be " +
          ', '.join(expected))
    return 1
76
if __name__ == "__main__":
    # Parse the list of files to check from the command line.
    parser = ArgumentParser(description='Check #include orders')
    parser.add_argument('files', nargs='*', help='Check files.')
    parsed = parser.parse_args()

    # Each check_includes() call returns 0 or 1; the total drives the exit code.
    print_error_and_exit(sum(check_includes(path) for path in parsed.files))
87