1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2016, Linaro Limited
4  */
5 
6 #include <stdint.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <sys/param.h>
10 #include <sys/types.h>
11 
12 #include "rand_stream.h"
13 
/* Smallest stream buffer, in bytes, that rand_stream_alloc() will allocate. */
#define STREAM_BUF_MIN_SIZE	4

/*
 * State of a deterministic pseudo-random byte stream.
 *
 * Bytes are generated one 32-bit LCG word at a time into word_buf and are
 * handed out to callers through stream_buf, which can be peeked at directly.
 */
struct rand_stream {
	/* LCG state; its in-memory bytes become the contents of word_buf. */
	int32_t seed;
	/* Bytes of the most recently generated LCG word. */
	uint8_t word_buf[4];
	/* Next unused byte in word_buf; sizeof(word_buf) means "exhausted". */
	size_t w_offs;
	/* Capacity of stream_buf in bytes. */
	size_t sb_size;
	/* Next unconsumed byte in stream_buf; sb_size means "needs refill". */
	size_t sb_offs;
	/* Buffered stream bytes (flexible array member, sb_size bytes long). */
	uint8_t stream_buf[];
};
24 
rand_stream_alloc(int seed,size_t stream_buffer_size)25 struct rand_stream *rand_stream_alloc(int seed, size_t stream_buffer_size)
26 {
27 	size_t sb_size = MAX(stream_buffer_size, STREAM_BUF_MIN_SIZE);
28 	struct rand_stream *rs = calloc(1, sizeof(*rs) + sb_size);
29 
30 	if (!rs)
31 		return NULL;
32 
33 	rs->sb_size = sb_size;;
34 	rs->sb_offs = rs->sb_size;
35 	rs->w_offs = sizeof(rs->word_buf);
36 	rs->seed = seed;
37 
38 	return rs;
39 }
40 
/*
 * Release a stream previously returned by rand_stream_alloc().
 * A NULL stream is accepted, since free(NULL) is a no-op.
 */
void rand_stream_free(struct rand_stream *rs)
{
	void *mem = rs;

	free(mem);
}
45 
get_random(struct rand_stream * rs,uint8_t * buf,size_t blen)46 static void get_random(struct rand_stream *rs, uint8_t *buf, size_t blen)
47 {
48 	uint8_t *b = buf;
49 	size_t l = blen;
50 
51 
52 	/*
53 	 * This function uses an LCG,
54 	 * https://en.wikipedia.org/wiki/Linear_congruential_generator
55 	 * to generate the byte stream.
56 	 */
57 
58 	while (l) {
59 		size_t t = MIN(sizeof(rs->word_buf) - rs->w_offs, l);
60 
61 		memcpy(b, rs->word_buf + rs->w_offs, t);
62 		rs->w_offs += t;
63 		l -= t;
64 		b += t;
65 
66 		if (rs->w_offs == sizeof(rs->word_buf)) {
67 			rs->seed = rs->seed * 1103515245 + 12345;
68 			memcpy(rs->word_buf, &rs->seed, sizeof(rs->seed));
69 			rs->w_offs = 0;
70 		}
71 	}
72 }
73 
rand_stream_peek(struct rand_stream * rs,size_t * num_bytes)74 const void *rand_stream_peek(struct rand_stream *rs, size_t *num_bytes)
75 {
76 	if (rs->sb_offs == rs->sb_size) {
77 		rs->sb_offs = 0;
78 		get_random(rs, rs->stream_buf, rs->sb_size);
79 	}
80 
81 	*num_bytes = MIN(*num_bytes, rs->sb_size - rs->sb_offs);
82 	return rs->stream_buf + rs->sb_offs;
83 }
84 
/*
 * Copy the next @num_bytes bytes of the stream into @buf and consume them.
 *
 * Bytes still sitting in the stream buffer are used first. Any remainder
 * is generated straight into @buf. In that case the buffer has just been
 * drained, so the next peek refills it from where this read stopped.
 */
void rand_stream_read(struct rand_stream *rs, void *buf, size_t num_bytes)
{
	uint8_t *dst = buf;
	size_t from_buf = num_bytes;
	size_t rest;
	const void *src;

	/* Must be its own statement: the call updates from_buf. */
	src = rand_stream_peek(rs, &from_buf);
	memcpy(dst, src, from_buf);
	rand_stream_advance(rs, from_buf);

	rest = num_bytes - from_buf;
	if (rest != 0)
		get_random(rs, dst + from_buf, rest);
}
97 
rand_stream_advance(struct rand_stream * rs,size_t num_bytes)98 void rand_stream_advance(struct rand_stream *rs, size_t num_bytes)
99 {
100 	size_t nb = num_bytes;
101 
102 	if (nb <= (rs->sb_size - rs->sb_offs)) {
103 		rs->sb_offs += nb;
104 		return;
105 	}
106 
107 	nb -= rs->sb_size - rs->sb_offs;
108 	rs->sb_offs = rs->sb_size;
109 
110 	while (nb > rs->sb_size) {
111 		get_random(rs, rs->stream_buf, rs->sb_size);
112 		nb -= rs->sb_size;
113 	}
114 
115 	get_random(rs, rs->stream_buf, rs->sb_size);
116 	rs->sb_offs = nb;
117 }
118