// SPDX-License-Identifier: GPL-2.0
#include <vmlinux.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_helpers.h>

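/* Map value layout under test: an unreferenced (untrusted) kptr and a
 * reference-counted kptr embedded in the same map value.
 */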
struct map_value {
	struct prog_test_ref_kfunc __kptr *unref_ptr;
	struct prog_test_ref_kfunc __kptr_ref *ref_ptr;
};

struct array_map {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} array_map SEC(".maps");

struct hash_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} hash_map SEC(".maps");

struct hash_malloc_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} hash_malloc_map SEC(".maps");

struct lru_hash_map {
	__uint(type, BPF_MAP_TYPE_LRU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} lru_hash_map SEC(".maps");

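/* Define an outer map-in-map whose single slot is statically populated with
 * the given inner map, so kptr handling can also be exercised through maps
 * looked up from an outer map.
 */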
#define DEFINE_MAP_OF_MAP(map_type, inner_map_type, name)	\
	struct {						\
		__uint(type, map_type);				\
		__uint(max_entries, 1);				\
		__uint(key_size, sizeof(int));			\
		__uint(value_size, sizeof(int));		\
		__array(values, struct inner_map_type);		\
	} name SEC(".maps") = {					\
		.values = { [0] = &inner_map_type },		\
	}

DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_map, array_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_map, array_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_malloc_map, array_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, lru_hash_map, array_of_lru_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, array_map, hash_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_map, hash_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_malloc_map, hash_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, lru_hash_map, hash_of_lru_hash_maps);

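/* Test kfuncs provided by the kernel for selftests: acquire and kptr_get
 * return a referenced prog_test_ref_kfunc pointer, release drops that
 * reference.
 */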
extern struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(unsigned long *sp) __ksym;
extern struct prog_test_ref_kfunc *
bpf_kfunc_call_test_kptr_get(struct prog_test_ref_kfunc **p, int a, int b) __ksym;
extern void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) __ksym;

#define WRITE_ONCE(x, val) ((*(volatile typeof(x) *) &(x)) = (val))

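/* Exercise loads and stores of the unreferenced kptr field: storing the
 * (possibly NULL) untrusted pointer back, dereferencing it after a NULL
 * check, and storing NULL are all expected to be accepted by the verifier.
 */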
static void test_kptr_unref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->unref_ptr;
	/* store untrusted_ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100)
		return;
	/* store untrusted_ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	/* store NULL */
	WRITE_ONCE(v->unref_ptr, NULL);
}

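/* Exercise the referenced kptr field: a plain load may only be stored into the
 * unreferenced slot, while ownership moves in and out of the map through
 * bpf_kptr_xchg(), with every pointer taken out eventually released.
 */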
static void test_kptr_ref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->ref_ptr;
	/* store ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100)
		return;
	/* store NULL */
	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	/* store ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	bpf_kfunc_call_test_release(p);

	p = bpf_kfunc_call_test_acquire(&(unsigned long){0});
	if (!p)
		return;
	/* store ptr_ */
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	bpf_kfunc_call_test_release(p);
}

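/* Take an additional reference on the kptr stored in the map via the kptr_get
 * kfunc, then release it again.
 */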
static void test_kptr_get(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	bpf_kfunc_call_test_release(p);
}

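/* Run all three kptr exercises against a single map value. */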
static void test_kptr(struct map_value *v)
{
	test_kptr_unref(v);
	test_kptr_ref(v);
	test_kptr_get(v);
}

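/* Run the kptr exercises on a value from each supported map type. */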
SEC("tc")
int test_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;

#define TEST(map)					\
	v = bpf_map_lookup_elem(&map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

#undef TEST
	return 0;
}

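/* Repeat the kptr exercises through inner maps looked up from every
 * combination of outer and inner map type defined above.
 */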
SEC("tc")
int test_map_in_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;
	void *map;

#define TEST(map_in_map)				\
	map = bpf_map_lookup_elem(&map_in_map, &key);	\
	if (!map)					\
		return 0;				\
	v = bpf_map_lookup_elem(map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_of_array_maps);
	TEST(array_of_hash_maps);
	TEST(array_of_hash_malloc_maps);
	TEST(array_of_lru_hash_maps);
	TEST(hash_of_array_maps);
	TEST(hash_of_hash_maps);
	TEST(hash_of_hash_malloc_maps);
	TEST(hash_of_lru_hash_maps);

#undef TEST
	return 0;
}

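/* Track the object's refcount (observed via p->next, which leads back to the
 * same underlying test object) as a referenced pointer is acquired, exchanged
 * into and out of array_map, re-acquired with kptr_get, and released. A
 * reference is deliberately left in the map for test_map_kptr_ref2.
 */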
SEC("tc")
int test_map_kptr_ref(struct __sk_buff *ctx)
{
	struct prog_test_ref_kfunc *p, *p_st;
	unsigned long arg = 0;
	struct map_value *v;
	int key = 0, ret;

	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 1;

	p_st = p->next;
	if (p_st->cnt.refs.counter != 2) {
		ret = 2;
		goto end;
	}

	v = bpf_map_lookup_elem(&array_map, &key);
	if (!v) {
		ret = 3;
		goto end;
	}

	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 4;
		goto end;
	}
	if (p_st->cnt.refs.counter != 2)
		return 5;

	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
	if (!p)
		return 6;
	if (p_st->cnt.refs.counter != 3) {
		ret = 7;
		goto end;
	}
	bpf_kfunc_call_test_release(p);
	if (p_st->cnt.refs.counter != 2)
		return 8;

	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 9;
	bpf_kfunc_call_test_release(p);
	if (p_st->cnt.refs.counter != 1)
		return 10;

	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 11;
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 12;
		goto end;
	}
	if (p_st->cnt.refs.counter != 2)
		return 13;
	/* Leave in map */

	return 0;
end:
	bpf_kfunc_call_test_release(p);
	return ret;
}

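/* Verify the reference left in array_map by test_map_kptr_ref: the refcount
 * must still be 2, and exchanging the pointer out of and back into the map
 * must not change it.
 */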
SEC("tc")
int test_map_kptr_ref2(struct __sk_buff *ctx)
{
	struct prog_test_ref_kfunc *p, *p_st;
	struct map_value *v;
	int key = 0;

	v = bpf_map_lookup_elem(&array_map, &key);
	if (!v)
		return 1;

	p_st = v->ref_ptr;
	if (!p_st || p_st->cnt.refs.counter != 2)
		return 2;

	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 3;
	if (p_st->cnt.refs.counter != 2) {
		bpf_kfunc_call_test_release(p);
		return 4;
	}

	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		bpf_kfunc_call_test_release(p);
		return 5;
	}
	if (p_st->cnt.refs.counter != 2)
		return 6;

	return 0;
}

char _license[] SEC("license") = "GPL";