Commit 2269477f authored by Stefan Westerfeld's avatar Stefan Westerfeld

Get rid of single global key for PRNG.

Signed-off-by: Stefan Westerfeld's avatarStefan Westerfeld <stefan@space.twc.de>
parent 21452170
......@@ -291,7 +291,7 @@ test_snr (const string& orig_file, const string& wm_file)
}
int
test_clip (const string& in_file, const string& out_file, int seed, int time_seconds)
test_clip (const Key& key, const string& in_file, const string& out_file, int seed, int time_seconds)
{
WavData in_data;
Error err = in_data.load (in_file);
......@@ -301,7 +301,7 @@ test_clip (const string& in_file, const string& out_file, int seed, int time_sec
return 1;
}
bool done = false;
Random rng (seed, /* there is no stream for this test */ Random::Stream::data_up_down);
Random rng (key, seed, /* there is no stream for this test */ Random::Stream::data_up_down);
size_t start_point, end_point;
do
{
......@@ -330,9 +330,9 @@ test_clip (const string& in_file, const string& out_file, int seed, int time_sec
}
int
test_speed (int seed)
test_speed (const Key& key, int seed)
{
Random rng (seed, /* there is no stream for this test */ Random::Stream::data_up_down);
Random rng (key, seed, /* there is no stream for this test */ Random::Stream::data_up_down);
double low = 0.85;
double high = 1.15;
printf ("%.6f\n", low + (rng() / double (UINT64_MAX)) * (high - low));
......@@ -340,13 +340,13 @@ test_speed (int seed)
}
int
test_gen_noise (const string& out_file, double seconds, int rate)
test_gen_noise (const Key& key, const string& out_file, double seconds, int rate)
{
const int channels = 2;
const int bits = 16;
vector<float> noise;
Random rng (0, /* there is no stream for this test */ Random::Stream::data_up_down);
Random rng (key, 0, /* there is no stream for this test */ Random::Stream::data_up_down);
for (size_t i = 0; i < size_t (rate * seconds) * channels; i++)
noise.push_back (rng.random_double() * 2 - 1);
......@@ -540,21 +540,10 @@ parse_shared_options (ArgParser& ap)
{
int i;
float f;
string s;
if (ap.parse_opt ("--strength", f))
{
Params::water_delta = f / 1000;
}
if (ap.parse_opt ("--key", s))
{
Params::have_key++;
Random::load_global_key (s);
}
if (ap.parse_opt ("--test-key", i))
{
Params::have_key++;
Random::set_global_test_key (i);
}
if (ap.parse_opt ("--short", i))
{
Params::payload_size = i;
......@@ -570,11 +559,31 @@ parse_shared_options (ArgParser& ap)
{
Params::mix = false;
}
if (Params::have_key > 1)
}
Key
parse_key (ArgParser& ap)
{
Key key; // default initialized with zero key
string s;
int i;
int have_key = 0;
if (ap.parse_opt ("--key", s))
{
have_key++;
key.load_key (s);
}
if (ap.parse_opt ("--test-key", i))
{
have_key++;
key.set_test_key (i);
}
if (have_key > 1)
{
error ("audiowmark: watermark key can at most be set once (--key / --test-key option)\n");
exit (1);
}
return key;
}
void
......@@ -778,8 +787,9 @@ main (int argc, char **argv)
ap.parse_opt ("--bit-rate", Params::hls_bit_rate);
Key key = parse_key (ap);
args = parse_positional (ap, "input_ts", "output_ts", "message_hex");
return hls_add (args[0], args[1], args[2]);
return hls_add (key, args[0], args[1], args[2]);
}
else if (ap.parse_cmd ("hls-prepare"))
{
......@@ -793,16 +803,18 @@ main (int argc, char **argv)
parse_shared_options (ap);
parse_add_options (ap);
Key key = parse_key (ap);
args = parse_positional (ap, "input_wav", "watermarked_wav", "message_hex");
return add_watermark (args[0], args[1], args[2]);
return add_watermark (key, args[0], args[1], args[2]);
}
else if (ap.parse_cmd ("get"))
{
parse_shared_options (ap);
parse_get_options (ap);
Key key = parse_key (ap); // TODO: key list
args = parse_positional (ap, "watermarked_wav");
return get_watermark (args[0], /* no ber */ "");
return get_watermark (key, args[0], /* no ber */ "");
}
else if (ap.parse_cmd ("cmp"))
{
......@@ -811,8 +823,9 @@ main (int argc, char **argv)
ap.parse_opt ("--expect-matches", Params::expect_matches);
Key key = parse_key (ap); // TODO: key list
args = parse_positional (ap, "watermarked_wav", "message_hex");
return get_watermark (args[0], args[1]);
return get_watermark (key, args[0], args[1]);
}
else if (ap.parse_cmd ("gen-key"))
{
......@@ -843,22 +856,25 @@ main (int argc, char **argv)
{
parse_shared_options (ap);
Key key = parse_key (ap);
args = parse_positional (ap, "input_wav", "output_wav", "seed", "seconds");
return test_clip (args[0], args[1], atoi (args[2].c_str()), atoi (args[3].c_str()));
return test_clip (key, args[0], args[1], atoi (args[2].c_str()), atoi (args[3].c_str()));
}
else if (ap.parse_cmd ("test-speed"))
{
parse_shared_options (ap);
Key key = parse_key (ap);
args = parse_positional (ap, "seed");
return test_speed (atoi (args[0].c_str()));
return test_speed (key, atoi (args[0].c_str()));
}
else if (ap.parse_cmd ("test-gen-noise"))
{
parse_shared_options (ap);
Key key = parse_key (ap);
args = parse_positional (ap, "output_wav", "seconds", "sample_rate");
return test_gen_noise (args[0], atof (args[1].c_str()), atoi (args[2].c_str()));
return test_gen_noise (key, args[0], atof (args[1].c_str()), atoi (args[2].c_str()));
}
else if (ap.parse_cmd ("test-change-speed"))
{
......
......@@ -47,7 +47,7 @@ hls_prepare (const string& in_dir, const string& out_dir, const string& filename
}
int
hls_add (const string& infile, const string& outfile, const string& bits)
hls_add (const Key& key, const string& infile, const string& outfile, const string& bits)
{
error ("audiowmark: hls support is not available in this build of audiowmark\n");
return 1;
......@@ -201,7 +201,7 @@ ff_decode (const string& filename, WavData& out_wav_data)
}
int
hls_add (const string& infile, const string& outfile, const string& bits)
hls_add (const Key& key, const string& infile, const string& outfile, const string& bits)
{
TSReader reader;
......@@ -276,7 +276,7 @@ hls_add (const string& infile, const string& outfile, const string& bits)
return 1;
}
int wm_rc = add_stream_watermark (&in_stream, &out_stream, bits, start_pos - prev_size);
int wm_rc = add_stream_watermark (key, &in_stream, &out_stream, bits, start_pos - prev_size);
if (wm_rc != 0)
return wm_rc;
......
......@@ -20,7 +20,7 @@
#include <string>
int hls_add (const std::string& infile, const std::string& outfile, const std::string& bits);
int hls_add (const Key& key, const std::string& infile, const std::string& outfile, const std::string& bits);
int hls_prepare (const std::string& in_dir, const std::string& out_dir, const std::string& filename, const std::string& audio_master);
Error ff_decode (const std::string& filename, WavData& out_wav_data);
......
......@@ -52,7 +52,6 @@ gcrypt_init()
}
static vector<unsigned char> aes_key (16); // 128 bits
static constexpr auto GCRY_CIPHER = GCRY_CIPHER_AES128;
static void
......@@ -95,20 +94,20 @@ print (const string& label, const vector<unsigned char>& data)
}
#endif
Random::Random (uint64_t start_seed, Stream stream)
Random::Random (const Key& key, uint64_t start_seed, Stream stream)
{
gcrypt_init();
gcry_error_t gcry_ret = gcry_cipher_open (&aes_ctr_cipher, GCRY_CIPHER, GCRY_CIPHER_MODE_CTR, 0);
die_on_error ("gcry_cipher_open", gcry_ret);
gcry_ret = gcry_cipher_setkey (aes_ctr_cipher, &aes_key[0], aes_key.size());
gcry_ret = gcry_cipher_setkey (aes_ctr_cipher, key.aes_key(), Key::SIZE);
die_on_error ("gcry_cipher_setkey", gcry_ret);
gcry_ret = gcry_cipher_open (&seed_cipher, GCRY_CIPHER, GCRY_CIPHER_MODE_ECB, 0);
die_on_error ("gcry_cipher_open", gcry_ret);
gcry_ret = gcry_cipher_setkey (seed_cipher, &aes_key[0], aes_key.size());
gcry_ret = gcry_cipher_setkey (seed_cipher, key.aes_key(), Key::SIZE);
die_on_error ("gcry_cipher_setkey", gcry_ret);
seed (start_seed, stream);
......@@ -120,19 +119,19 @@ Random::seed (uint64_t seed, Stream stream)
buffer_pos = 0;
buffer.clear();
unsigned char plain_text[aes_key.size()];
unsigned char cipher_text[aes_key.size()];
unsigned char plain_text[Key::SIZE];
unsigned char cipher_text[Key::SIZE];
memset (plain_text, 0, sizeof (plain_text));
uint64_to_buffer (seed, &plain_text[0]);
plain_text[8] = uint8_t (stream);
gcry_error_t gcry_ret = gcry_cipher_encrypt (seed_cipher, &cipher_text[0], aes_key.size(),
&plain_text[0], aes_key.size());
gcry_error_t gcry_ret = gcry_cipher_encrypt (seed_cipher, &cipher_text[0], Key::SIZE,
&plain_text[0], Key::SIZE);
die_on_error ("gcry_cipher_encrypt", gcry_ret);
gcry_ret = gcry_cipher_setctr (aes_ctr_cipher, &cipher_text[0], aes_key.size());
gcry_ret = gcry_cipher_setctr (aes_ctr_cipher, &cipher_text[0], Key::SIZE);
die_on_error ("gcry_cipher_setctr", gcry_ret);
}
......@@ -172,14 +171,42 @@ Random::die_on_error (const char *func, gcry_error_t err)
}
}
string
Random::gen_key()
{
gcrypt_init();
vector<unsigned char> key (16);
gcry_randomize (&key[0], 16, /* long term key material strength */ GCRY_VERY_STRONG_RANDOM);
return vec_to_hex_str (key);
}
uint64_t
Random::seed_from_hash (const vector<float>& floats)
{
unsigned char hash[20];
gcry_md_hash_buffer (GCRY_MD_SHA1, hash, &floats[0], floats.size() * sizeof (float));
return uint64_from_buffer (hash);
}
Key::Key() :
m_aes_key (SIZE)
{
}
Key::~Key()
{
std::fill (m_aes_key.begin(), m_aes_key.end(), 0);
}
void
Random::set_global_test_key (uint64_t key)
Key::set_test_key (uint64_t key)
{
uint64_to_buffer (key, &aes_key[0]);
uint64_to_buffer (key, m_aes_key.data());
}
void
Random::load_global_key (const string& key_file)
Key::load_key (const string& key_file)
{
FILE *f = fopen (key_file.c_str(), "r");
if (!f)
......@@ -207,12 +234,12 @@ Random::load_global_key (const string& key_file)
{
/* line containing aes key */
vector<unsigned char> key = hex_str_to_vec (match[1].str());
if (key.size() != aes_key.size())
if (key.size() != Key::SIZE)
{
error ("audiowmark: wrong key length in key file '%s', line %d\n => required key length is %zd bits\n", key_file.c_str(), line, aes_key.size() * 8);
error ("audiowmark: wrong key length in key file '%s', line %d\n => required key length is %zd bits\n", key_file.c_str(), line, Key::SIZE * 8);
exit (1);
}
aes_key = key;
m_aes_key = key;
keys++;
}
else
......@@ -236,20 +263,9 @@ Random::load_global_key (const string& key_file)
}
}
string
Random::gen_key()
const unsigned char *
Key::aes_key() const
{
gcrypt_init();
vector<unsigned char> key (16);
gcry_randomize (&key[0], 16, /* long term key material strength */ GCRY_VERY_STRONG_RANDOM);
return vec_to_hex_str (key);
}
uint64_t
Random::seed_from_hash (const vector<float>& floats)
{
unsigned char hash[20];
gcry_md_hash_buffer (GCRY_MD_SHA1, hash, &floats[0], floats.size() * sizeof (float));
return uint64_from_buffer (hash);
assert (m_aes_key.size() == SIZE);
return m_aes_key.data();
}
......@@ -25,6 +25,20 @@
#include <string>
#include <random>
class Key
{
std::vector<unsigned char> m_aes_key;
public:
static constexpr size_t SIZE = 16; /* 128 bits */
Key();
~Key();
void set_test_key (uint64_t key);
void load_key (const std::string& filename);
const unsigned char *aes_key() const;
};
class Random
{
public:
......@@ -46,7 +60,7 @@ private:
void die_on_error (const char *func, gcry_error_t error);
public:
Random (uint64_t seed, Stream stream);
Random (const Key& key, uint64_t seed, Stream stream);
~Random();
typedef uint64_t result_type;
......@@ -90,8 +104,6 @@ public:
}
}
static void set_global_test_key (uint64_t seed);
static void load_global_key (const std::string& key_file);
static std::string gen_key();
static uint64_t seed_from_hash (const std::vector<float>& floats);
};
......
......@@ -27,7 +27,7 @@ using std::string;
using std::min;
void
SyncFinder::init_up_down (const WavData& wav_data, Mode mode)
SyncFinder::init_up_down (const Key& key, const WavData& wav_data, Mode mode)
{
sync_bits.clear();
......@@ -37,7 +37,7 @@ SyncFinder::init_up_down (const WavData& wav_data, Mode mode)
const int block_count = mode == Mode::CLIP ? 2 : 1;
size_t n_bands = Params::max_band - Params::min_band + 1;
UpDownGen up_down_gen (Random::Stream::sync_up_down);
UpDownGen up_down_gen (key, Random::Stream::sync_up_down);
for (int bit = 0; bit < Params::sync_bits; bit++)
{
vector<FrameBit> frame_bits;
......@@ -49,7 +49,7 @@ SyncFinder::init_up_down (const WavData& wav_data, Mode mode)
for (int block = 0; block < block_count; block++)
{
FrameBit frame_bit;
frame_bit.frame = sync_frame_pos (f + bit * Params::sync_frames_per_bit) + block * first_block_end;
frame_bit.frame = sync_frame_pos (key, f + bit * Params::sync_frames_per_bit) + block * first_block_end;
for (int ch = 0; ch < wav_data.n_channels(); ch++)
{
if (block == 0)
......@@ -254,7 +254,7 @@ SyncFinder::sync_select_n_best (vector<Score>& sync_scores, size_t n)
}
void
SyncFinder::search_refine (const WavData& wav_data, Mode mode, vector<Score>& sync_scores)
SyncFinder::search_refine (const Key& key, const WavData& wav_data, Mode mode, vector<Score>& sync_scores)
{
vector<float> fft_db;
vector<char> have_frames;
......@@ -268,9 +268,9 @@ SyncFinder::search_refine (const WavData& wav_data, Mode mode, vector<Score>& sy
vector<char> want_frames (total_frame_count);
for (size_t f = 0; f < mark_sync_frame_count(); f++)
{
want_frames[sync_frame_pos (f)] = 1;
want_frames[sync_frame_pos (key, f)] = 1;
if (mode == Mode::CLIP)
want_frames[first_block_end + sync_frame_pos (f)] = 1;
want_frames[first_block_end + sync_frame_pos (key, f)] = 1;
}
for (const auto& score : sync_scores)
......@@ -326,12 +326,12 @@ SyncFinder::fake_sync (const WavData& wav_data, Mode mode)
}
vector<SyncFinder::Score>
SyncFinder::search (const WavData& wav_data, Mode mode)
SyncFinder::search (const Key& key, const WavData& wav_data, Mode mode)
{
if (Params::test_no_sync)
return fake_sync (wav_data, mode);
init_up_down (wav_data, mode);
init_up_down (key, wav_data, mode);
if (mode == Mode::CLIP)
{
......@@ -350,15 +350,15 @@ SyncFinder::search (const WavData& wav_data, Mode mode)
if (mode == Mode::CLIP)
sync_select_n_best (sync_scores, 5);
search_refine (wav_data, mode, sync_scores);
search_refine (key, wav_data, mode, sync_scores);
return sync_scores;
}
vector<vector<SyncFinder::FrameBit>>
SyncFinder::get_sync_bits (const WavData& wav_data, Mode mode)
SyncFinder::get_sync_bits (const Key& key, const WavData& wav_data, Mode mode)
{
init_up_down (wav_data, mode);
init_up_down (key, wav_data, mode);
return sync_bits;
}
......
......@@ -20,6 +20,7 @@
#include "convcode.hh"
#include "wavdata.hh"
#include "random.hh"
/*
* The SyncFinder class searches for sync bits in an input WavData. It is used
......@@ -80,7 +81,7 @@ public:
private:
std::vector<std::vector<FrameBit>> sync_bits;
void init_up_down (const WavData& wav_data, Mode mode);
void init_up_down (const Key& key, const WavData& wav_data, Mode mode);
double sync_decode (const WavData& wav_data, const size_t start_frame,
const std::vector<float>& fft_out_db,
const std::vector<char>& have_frames,
......@@ -89,15 +90,15 @@ private:
std::vector<Score> search_approx (const WavData& wav_data, Mode mode);
void sync_select_by_threshold (std::vector<Score>& sync_scores);
void sync_select_n_best (std::vector<Score>& sync_scores, size_t n);
void search_refine (const WavData& wav_data, Mode mode, std::vector<Score>& sync_scores);
void search_refine (const Key& key, const WavData& wav_data, Mode mode, std::vector<Score>& sync_scores);
std::vector<Score> fake_sync (const WavData& wav_data, Mode mode);
// non-zero sample range: [wav_data_first, wav_data_last)
size_t wav_data_first = 0;
size_t wav_data_last = 0;
public:
std::vector<Score> search (const WavData& wav_data, Mode mode);
std::vector<std::vector<FrameBit>> get_sync_bits (const WavData& wav_data, Mode mode);
std::vector<Score> search (const Key& key, const WavData& wav_data, Mode mode);
std::vector<std::vector<FrameBit>> get_sync_bits (const Key& key, const WavData& wav_data, Mode mode);
static double bit_quality (float umag, float dmag, int bit);
static double normalize_sync_quality (double raw_quality);
......
......@@ -116,14 +116,14 @@ public:
};
int
mark_zexpand (WavData& wav_data, size_t zero_frames, const string& bits)
mark_zexpand (const Key& key, WavData& wav_data, size_t zero_frames, const string& bits)
{
WDInputStream in_stream (&wav_data);
WavData wav_data_out ({ /* no samples */ }, wav_data.n_channels(), wav_data.sample_rate(), wav_data.bit_depth());
WDOutputStream out_stream (&wav_data_out);
int rc = add_stream_watermark (&in_stream, &out_stream, bits, zero_frames);
int rc = add_stream_watermark (key, &in_stream, &out_stream, bits, zero_frames);
if (rc != 0)
return rc;
......@@ -133,7 +133,7 @@ mark_zexpand (WavData& wav_data, size_t zero_frames, const string& bits)
}
int
test_seek (const string& in, const string& out, int pos, const string& bits)
test_seek (const Key& key, const string& in, const string& out, int pos, const string& bits)
{
vector<float> samples;
WavData wav_data;
......@@ -148,7 +148,7 @@ test_seek (const string& in, const string& out, int pos, const string& bits)
samples.erase (samples.begin(), samples.begin() + pos * wav_data.n_channels());
wav_data.set_samples (samples);
int rc = mark_zexpand (wav_data, pos, bits);
int rc = mark_zexpand (key, wav_data, pos, bits);
if (rc != 0)
{
return rc;
......@@ -168,14 +168,14 @@ test_seek (const string& in, const string& out, int pos, const string& bits)
}
int
seek_perf (int sample_rate, double seconds)
seek_perf (const Key& key, int sample_rate, double seconds)
{
vector<float> samples (100);
WavData wav_data (samples, 2, sample_rate, 16);
double start_time = get_time();
int rc = mark_zexpand (wav_data, seconds * sample_rate, "0c");
int rc = mark_zexpand (key, wav_data, seconds * sample_rate, "0c");
if (rc != 0)
return rc;
......@@ -191,13 +191,14 @@ seek_perf (int sample_rate, double seconds)
int
main (int argc, char **argv)
{
Key global_key;
if (argc == 6 && strcmp (argv[1], "test-seek") == 0)
{
return test_seek (argv[2], argv[3], atoi (argv[4]), argv[5]);
return test_seek (global_key, argv[2], argv[3], atoi (argv[4]), argv[5]);
}
else if (argc == 4 && strcmp (argv[1], "seek-perf") == 0)
{
return seek_perf (atoi (argv[2]), atof (argv[3]));
return seek_perf (global_key, atoi (argv[2]), atof (argv[3]));
}
else if (argc == 4 && strcmp (argv[1], "ff-decode") == 0)
{
......
......@@ -26,7 +26,8 @@ using std::string;
int
main (int argc, char **argv)
{
Random rng (0xf00f1234b00b5678U, Random::Stream::bit_order);
Key key;
Random rng (key, 0xf00f1234b00b5678U, Random::Stream::bit_order);
for (size_t i = 0; i < 20; i++)
{
uint64_t x = rng();
......
......@@ -83,7 +83,7 @@ apply_frame_mod (const vector<FrameMod>& frame_mod, const vector<complex<float>>
}
static void
mark_data (vector<vector<FrameMod>>& frame_mod, const vector<int>& bitvec)
mark_data (const Key& key, vector<vector<FrameMod>>& frame_mod, const vector<int>& bitvec)
{
assert (bitvec.size() == mark_data_frame_count() / Params::frames_per_bit);
assert (frame_mod.size() >= mark_data_frame_count());
......@@ -92,7 +92,7 @@ mark_data (vector<vector<FrameMod>>& frame_mod, const vector<int>& bitvec)
if (Params::mix)
{
vector<MixEntry> mix_entries = gen_mix_entries();
vector<MixEntry> mix_entries = gen_mix_entries (key);
for (int f = 0; f < frame_count; f++)
{
......@@ -113,11 +113,11 @@ mark_data (vector<vector<FrameMod>>& frame_mod, const vector<int>& bitvec)
}
else
{
UpDownGen up_down_gen (Random::Stream::data_up_down);
UpDownGen up_down_gen (key, Random::Stream::data_up_down);
for (int f = 0; f < frame_count; f++)
{
size_t index = data_frame_pos (f);
size_t index = data_frame_pos (key, f);
prepare_frame_mod (up_down_gen, f, frame_mod[index], bitvec[f / Params::frames_per_bit]);
}
......@@ -125,17 +125,17 @@ mark_data (vector<vector<FrameMod>>& frame_mod, const vector<int>& bitvec)
}
static void
mark_sync (vector<vector<FrameMod>>& frame_mod, int ab)
mark_sync (const Key& key, vector<vector<FrameMod>>& frame_mod, int ab)
{
const int frame_count = mark_sync_frame_count();
assert (frame_mod.size() >= mark_sync_frame_count());
UpDownGen up_down_gen (Random::Stream::sync_up_down);
UpDownGen up_down_gen (key, Random::Stream::sync_up_down);
// sync block always written in linear order (no mix)
for (int f = 0; f < frame_count; f++)
{
size_t index = sync_frame_pos (f);
size_t index = sync_frame_pos (key, f);
int data_bit = (f / Params::sync_frames_per_bit + ab) & 1; /* write 010101 for a block, 101010 for b block */
prepare_frame_mod (up_down_gen, f, frame_mod[index], data_bit);
......@@ -143,7 +143,7 @@ mark_sync (vector<vector<FrameMod>>& frame_mod, int ab)
}
static void
init_frame_mod_vec (vector<vector<FrameMod>>& frame_mod_vec, int ab, const vector<int>& bitvec)
init_frame_mod_vec (const Key& key, vector<vector<FrameMod>>& frame_mod_vec, int ab, const vector<int>& bitvec)
{
frame_mod_vec.resize (mark_sync_frame_count() + mark_data_frame_count());
......@@ -152,10 +152,10 @@ init_frame_mod_vec (vector<vector<FrameMod>>& frame_mod_vec, int ab, const vecto
/* forward error correction */
ConvBlockType block_type = ab ? ConvBlockType::b : ConvBlockType::a;
vector<int> bitvec_fec = randomize_bit_order (code_encode (block_type, bitvec), /* encode */ true);
vector<int> bitvec_fec = randomize_bit_order (key, code_encode (block_type, bitvec), /* encode */ true);
mark_sync (frame_mod_vec, ab);
mark_data (frame_mod_vec, bitvec_fec);
mark_sync (key, frame_mod_vec, ab);
mark_data (key, frame_mod_vec, bitvec_fec);
}
/* synthesizes a watermark stream (overlap add with synthesis window)
......@@ -292,7 +292,7 @@ public:
frame_number = 2 * frames_per_block - Params::frames_pad_start;
}
vector<float>
run (const vector<float>& samples)
run (const Key& key, const vector<float>& samples)
{
assert (samples.size() == Params::frame_size * n_channels);
......@@ -302,7 +302,7 @@ public:
for (int ch = 0; ch < n_channels; ch++)
fft_delta_spect.push_back (vector<complex<float>> (fft_out.back().size()));
const vector<FrameMod>& frame_mod = get_frame_mod();
const vector<FrameMod>& frame_mod = get_frame_mod (key);
for (int ch = 0; ch < n_channels; ch++)
apply_frame_mod (frame_mod, fft_out[ch], fft_delta_spect[ch]);
......@@ -321,20 +321,20 @@ public:
return wm_synth.skip (zeros);
}
const vector<FrameMod>&
get_frame_mod()
get_frame_mod (const Key& key)
{
const size_t f = frame_number % (frames_per_block * 2);
if (f >= frames_per_block) /* B block */
{
if (frame_mod_vec_b.empty())
init_frame_mod_vec (frame_mod_vec_b, 1, bitvec);
init_frame_mod_vec (key, frame_mod_vec_b, 1, bitvec);
return frame_mod_vec_b[f - frames_per_block];
}
else /* A block */
{
if (frame_mod_vec_a.empty())
init_frame_mod_vec (frame_mod_vec_a, 0, bitvec);
init_frame_mod_vec (key, frame_mod_vec_a, 0, bitvec);
return frame_mod_vec_a[f];
}
......@@ -529,12 +529,12 @@ public:
return true;
}
vector<float>
run (const vector<float>& samples)
run (const Key& key, const vector<float>& samples)
{
if (!need_resampler)
{
/* cheap case: if no resampling is necessary, just generate the watermark signal */
return wm_gen.run (samples);
return wm_gen.run (key, samples);
}
/* resample to the watermark sample rate */
......@@ -544,7 +544,7 @@ public:
vector<float> r_samples = in_resampler->read_frames (Params::frame_size);
/* generate watermark at normalized sample rate */
vector<float> wm_samples = wm_gen.run (r_samples);
vector<float> wm_samples = wm_gen.run (key, r_samples);
/* resample back to the original sample rate of the audio file */
out_resampler->write_frames (wm_samples);
......@@ -588,7 +588,7 @@ info_format (const string& label, const RawFormat& format)
}
int
add_stream_watermark (AudioInputStream *in_stream, AudioOutputStream *out_stream, const string& bits, size_t zero_frames)
add_stream_watermark (const Key& key, AudioInputStream *in_stream, AudioOutputStream *out_stream, const string& bits, size_t zero_frames)
{
auto bitvec = parse_payload (bits);
if (bitvec.empty())
......@@ -687,7 +687,7 @@ add_stream_watermark (AudioInputStream *in_stream, AudioOutputStream *out_stream
samples.resize (Params::frame_size * n_channels);
}
audio_buffer.write_frames (samples);
samples = wm_resampler.run (samples);
samples = wm_resampler.run (key, samples);
size_t to_read = samples.size() / n_channels;
vector<float> orig_samples = audio_buffer.read_frames (to_read);
assert (samples.size() == orig_samples.size());
......@@ -760,7 +760,7 @@ add_stream_watermark (AudioInputStream *in_stream, AudioOutputStream *out_stream
}
int
add_watermark (const string& infile, const string& outfile, const string& bits)
add_watermark (const Key& key, const string& infile, const string& outfile, const string& bits)
{
/* open input stream */
Error err;
......@@ -789,7 +789,7 @@ add_watermark (const string& infile, const string& outfile, const string& bits)
if (Params::output_format == Format::RAW)
info_format ("Raw Output", Params::raw_output_format);
return add_stream_watermark (in_stream.get(), out_stream.get(), bits, 0);
return add_stream_watermark (key, in_stream.get(), out_stream.get(), bits, 0);
}
......@@ -34,7 +34,6 @@ bool Params::detect_speed = false;
bool Params::detect_speed_patient = false;
double Params::try_speed = -1;
double Params::test_speed = -1;
int Params::have_key = 0;
size_t Params::payload_size = 128;
bool Params::payload_short = false;
int Params::test_cut = 0; // for sync test
......@@ -141,7 +140,7 @@ FFTAnalyzer::fft_range (const vector<float>& samples, size_t start_index, size_t
}
int
frame_pos (int f, bool sync)
frame_pos (const Key& key, int f, bool sync)
{
static vector<int> pos_vec;
......@@ -151,7 +150,7 @@ frame_pos (int f, bool sync)
for (int i = 0; i < frame_count; i++)
pos_vec.push_back (i);
Random random (0, Random::Stream::frame_position);
Random random (key, 0, Random::Stream::frame_position);
random.shuffle (pos_vec);
}
if (sync)
......@@ -169,15 +168,15 @@ frame_pos (int f, bool sync)
}
int
sync_frame_pos (int f)
sync_frame_pos (const Key& key, int f)
{
return frame_pos (f, true);
return frame_pos (key, f, true);
}
int
data_frame_pos (int f)
data_frame_pos (const Key& key, int f)
{
return frame_pos (f, false);
return frame_pos (key, f, false);
}
size_t
......@@ -193,16 +192,16 @@ mark_sync_frame_count()
}
vector<MixEntry>
gen_mix_entries()
gen_mix_entries (const Key& key)
{
const int frame_count = mark_data_frame_count();
vector<MixEntry> mix_entries (frame_count * Params::bands_per_frame);
UpDownGen up_down_gen (Random::Stream::data_up_down);
UpDownGen up_down_gen (key, Random::Stream::data_up_down);
int entry = 0;
for (int f = 0; f < frame_count; f++)
{
const int index = data_frame_pos (f);
const int index = data_frame_pos (key, f);
UpDownArray up, down;
up_down_gen.get (f, up, down);
......@@ -210,7 +209,7 @@ gen_mix_entries()
for (size_t i = 0; i < up.size(); i++)
mix_entries[entry++] = { index, up[i], down[i] };
}
Random random (/* seed */ 0, Random::Stream::mix);
Random random (key, /* seed */ 0, Random::Stream::mix);
random.shuffle (mix_entries);
return mix_entries;
......
......@@ -45,7 +45,6 @@ public:
static bool mix;
static bool hard; // hard decode bits? (soft decoding is better)
static bool snr; // compute/show snr while adding watermark
static int have_key;
static bool detect_speed;
static bool detect_speed_patient;
......@@ -94,9 +93,9 @@ class UpDownGen
std::vector<int> bands_reorder;
public:
UpDownGen (Random::Stream random_stream) :
UpDownGen (const Key& key, Random::Stream random_stream) :
random_stream (random_stream),
random (0, random_stream),
random (key, 0, random_stream),
bands_reorder (Params::max_band - Params::min_band + 1)
{
UpDownArray x;
......@@ -141,27 +140,27 @@ struct MixEntry
int down;
};
std::vector<MixEntry> gen_mix_entries();
std::vector<MixEntry> gen_mix_entries (const Key& key);
size_t mark_data_frame_count();
size_t mark_sync_frame_count();
int frame_count (const WavData& wav_data);
int sync_frame_pos (int f);
int data_frame_pos (int f);
int sync_frame_pos (const Key& key, int f);
int data_frame_pos (const Key& key, int f);
std::vector<int> parse_payload (const std::string& str);
template<class T> std::vector<T>
randomize_bit_order (const std::vector<T>& bit_vec, bool encode)
randomize_bit_order (const Key& key, const std::vector<T>& bit_vec, bool encode)
{
std::vector<unsigned int> order;
for (size_t i = 0; i < bit_vec.size(); i++)
order.push_back (i);
Random random (/* seed */ 0, Random::Stream::bit_order);
Random random (key, /* seed */ 0, Random::Stream::bit_order);
random.shuffle (order);
std::vector<T> out_bits (bit_vec.size());
......@@ -214,8 +213,8 @@ db_from_complex (std::complex<float> f, float min_dB)
return db_from_complex (f.real(), f.imag(), min_dB);
}
int add_stream_watermark (AudioInputStream *in_stream, AudioOutputStream *out_stream, const std::string& bits, size_t zero_frames);
int add_watermark (const std::string& infile, const std::string& outfile, const std::string& bits);
int get_watermark (const std::string& infile, const std::string& orig_pattern);
int add_stream_watermark (const Key& key, AudioInputStream *in_stream, AudioOutputStream *out_stream, const std::string& bits, size_t zero_frames);
int add_watermark (const Key& key, const std::string& infile, const std::string& outfile, const std::string& bits);
int get_watermark (const Key& key, const std::string& infile, const std::string& orig_pattern);
#endif /* AUDIOWMARK_WM_COMMON_HH */
......@@ -61,13 +61,13 @@ normalize_soft_bits (const vector<float>& soft_bits)
}
static vector<float>
mix_decode (vector<vector<complex<float>>>& fft_out, int n_channels)
mix_decode (const Key& key, vector<vector<complex<float>>>& fft_out, int n_channels)
{
vector<float> raw_bit_vec;
const int frame_count = mark_data_frame_count();
vector<MixEntry> mix_entries = gen_mix_entries();
vector<MixEntry> mix_entries = gen_mix_entries (key);
double umag = 0, dmag = 0;
for (int f = 0; f < frame_count; f++)
......@@ -98,9 +98,9 @@ mix_decode (vector<vector<complex<float>>>& fft_out, int n_channels)
}
static vector<float>
linear_decode (vector<vector<complex<float>>>& fft_out, int n_channels)
linear_decode (const Key& key, vector<vector<complex<float>>>& fft_out, int n_channels)
{
UpDownGen up_down_gen (Random::Stream::data_up_down);
UpDownGen up_down_gen (key, Random::Stream::data_up_down);
vector<float> raw_bit_vec;
const int frame_count = mark_data_frame_count();
......@@ -110,7 +110,7 @@ linear_decode (vector<vector<complex<float>>>& fft_out, int n_channels)
{
for (int ch = 0; ch < n_channels; ch++)
{
const size_t index = data_frame_pos (f) * n_channels + ch;
const size_t index = data_frame_pos (key, f) * n_channels + ch;
UpDownArray up, down;
up_down_gen.get (f, up, down);
......@@ -309,12 +309,12 @@ class BlockDecoder
vector<SyncFinder::Score> sync_scores; // stored here for sync debugging
public:
void
run (const WavData& wav_data, ResultSet& result_set)
run (const Key& key, const WavData& wav_data, ResultSet& result_set)
{
int total_count = 0;
SyncFinder sync_finder;
sync_scores = sync_finder.search (wav_data, SyncFinder::Mode::BLOCK);
sync_scores = sync_finder.search (key, wav_data, SyncFinder::Mode::BLOCK);
vector<float> raw_bit_vec_all (code_size (ConvBlockType::ab, Params::payload_size));
vector<int> raw_bit_vec_norm (2);
......@@ -339,15 +339,15 @@ public:
vector<float> raw_bit_vec;
if (Params::mix)
{
raw_bit_vec = mix_decode (fft_range_out, wav_data.n_channels());
raw_bit_vec = mix_decode (key, fft_range_out, wav_data.n_channels());
}
else
{
raw_bit_vec = linear_decode (fft_range_out, wav_data.n_channels());
raw_bit_vec = linear_decode (key, fft_range_out, wav_data.n_channels());
}
assert (raw_bit_vec.size() == code_size (ConvBlockType::a, Params::payload_size));
raw_bit_vec = randomize_bit_order (raw_bit_vec, /* encode */ false);
raw_bit_vec = randomize_bit_order (key, raw_bit_vec, /* encode */ false);
/* ---- deal with this pattern ---- */
float decode_error = 0;
......@@ -466,18 +466,18 @@ class ClipDecoder
const int frames_per_block = 0;
vector<float>
mix_or_linear_decode (vector<vector<complex<float>>>& fft_out, int n_channels)
mix_or_linear_decode (const Key& key, vector<vector<complex<float>>>& fft_out, int n_channels)
{
if (Params::mix)
return mix_decode (fft_out, n_channels);
return mix_decode (key, fft_out, n_channels);
else
return linear_decode (fft_out, n_channels);
return linear_decode (key, fft_out, n_channels);
}
void
run_padded (const WavData& wav_data, ResultSet& result_set, double time_offset_sec)
run_padded (const Key& key, const WavData& wav_data, ResultSet& result_set, double time_offset_sec)
{
SyncFinder sync_finder;
vector<SyncFinder::Score> sync_scores = sync_finder.search (wav_data, SyncFinder::Mode::CLIP);
vector<SyncFinder::Score> sync_scores = sync_finder.search (key, wav_data, SyncFinder::Mode::CLIP);
FFTAnalyzer fft_analyzer (wav_data.n_channels());
for (auto sync_score : sync_scores)
......@@ -488,8 +488,8 @@ class ClipDecoder
auto fft_range_out2 = fft_analyzer.fft_range (wav_data.samples(), index + count * Params::frame_size, count);
if (fft_range_out1.size() && fft_range_out2.size())
{
const auto raw_bit_vec1 = randomize_bit_order (mix_or_linear_decode (fft_range_out1, wav_data.n_channels()), /* encode */ false);
const auto raw_bit_vec2 = randomize_bit_order (mix_or_linear_decode (fft_range_out2, wav_data.n_channels()), /* encode */ false);
const auto raw_bit_vec1 = randomize_bit_order (key, mix_or_linear_decode (key, fft_range_out1, wav_data.n_channels()), /* encode */ false);
const auto raw_bit_vec2 = randomize_bit_order (key, mix_or_linear_decode (key, fft_range_out2, wav_data.n_channels()), /* encode */ false);
const size_t bits_per_block = raw_bit_vec1.size();
vector<float> raw_bit_vec;
for (size_t i = 0; i < bits_per_block; i++)
......@@ -519,7 +519,7 @@ class ClipDecoder
}
enum class Pos { START, END };
void
run_block (const WavData& wav_data, ResultSet& result_set, Pos pos)
run_block (const Key& key, const WavData& wav_data, ResultSet& result_set, Pos pos)
{
const size_t n = (frames_per_block + 5) * Params::frame_size * wav_data.n_channels();
......@@ -561,7 +561,7 @@ class ClipDecoder
ext_samples.insert (ext_samples.end(), pad_samples_end, 0);
WavData l_wav_data (ext_samples, wav_data.n_channels(), wav_data.sample_rate(), wav_data.bit_depth());
run_padded (l_wav_data, result_set, time_offset);
run_padded (key, l_wav_data, result_set, time_offset);
}
public:
ClipDecoder() :
......@@ -569,19 +569,19 @@ public:
{
}
void
run (const WavData& wav_data, ResultSet& result_set)
run (const Key& key, const WavData& wav_data, ResultSet& result_set)
{
const int wav_frames = wav_data.n_values() / (Params::frame_size * wav_data.n_channels());
if (wav_frames < frames_per_block * 3.1) /* clip decoder is only used for small wavs */
{
run_block (wav_data, result_set, Pos::START);
run_block (wav_data, result_set, Pos::END);
run_block (key, wav_data, result_set, Pos::START);
run_block (key, wav_data, result_set, Pos::END);
}
}
};
static int
decode_and_report (const WavData& wav_data, const vector<int>& orig_bits)
decode_and_report (const Key& key, const WavData& wav_data, const vector<int>& orig_bits)
{
ResultSet result_set;
double speed = 1.0;
......@@ -598,7 +598,7 @@ decode_and_report (const WavData& wav_data, const vector<int>& orig_bits)
if (Params::detect_speed || Params::detect_speed_patient || Params::try_speed > 0)
{
if (Params::detect_speed || Params::detect_speed_patient)
speed = detect_speed (wav_data, !orig_bits.empty());
speed = detect_speed (key, wav_data, !orig_bits.empty());
else
speed = Params::try_speed;
......@@ -611,19 +611,19 @@ decode_and_report (const WavData& wav_data, const vector<int>& orig_bits)
result_set.set_speed_pattern (true);
BlockDecoder block_decoder;
block_decoder.run (wav_data_speed, result_set);
block_decoder.run (key, wav_data_speed, result_set);
ClipDecoder clip_decoder;
clip_decoder.run (wav_data_speed, result_set);
clip_decoder.run (key, wav_data_speed, result_set);
result_set.set_speed_pattern (false);
}
}
BlockDecoder block_decoder;
block_decoder.run (wav_data, result_set);
block_decoder.run (key, wav_data, result_set);
ClipDecoder clip_decoder;
clip_decoder.run (wav_data, result_set);
clip_decoder.run (key, wav_data, result_set);
result_set.sort_by_time();
......@@ -655,7 +655,7 @@ decode_and_report (const WavData& wav_data, const vector<int>& orig_bits)
}
int
get_watermark (const string& infile, const string& orig_pattern)
get_watermark (const Key& key, const string& infile, const string& orig_pattern)
{
vector<int> orig_bitvec;
if (!orig_pattern.empty())
......@@ -686,10 +686,10 @@ get_watermark (const string& infile, const string& orig_pattern)
}
if (wav_data.sample_rate() == Params::mark_sample_rate)
{
return decode_and_report (wav_data, orig_bitvec);
return decode_and_report (key, wav_data, orig_bitvec);
}
else
{
return decode_and_report (resample (wav_data, Params::mark_sample_rate), orig_bitvec);
return decode_and_report (key, resample (wav_data, Params::mark_sample_rate), orig_bitvec);
}
}
......@@ -152,7 +152,7 @@ private:
const int frames_per_block;
public:
SpeedSync (const WavData& in_data, double center) :
SpeedSync (const Key& key, const WavData& in_data, double center) :
in_data (in_data),
center (center),
frames_per_block (mark_sync_frame_count() + mark_data_frame_count())
......@@ -160,7 +160,7 @@ public:
// constructor is run in the main thread; everything that is not thread-safe must happen here
SyncFinder sync_finder;
auto sync_finder_bits = sync_finder.get_sync_bits (in_data, SyncFinder::Mode::BLOCK);
auto sync_finder_bits = sync_finder.get_sync_bits (key, in_data, SyncFinder::Mode::BLOCK);
for (size_t bit = 0; bit < sync_finder_bits.size(); bit++)
{
for (const auto& frame_bit : sync_finder_bits[bit])
......@@ -479,12 +479,12 @@ public:
printf ("range = [ %.2f .. %.2f ]\n", bound (-1), bound (1));
}
vector<SpeedSync::Score> run_search (const SpeedScanParams& scan_params, const vector<double>& speeds);
vector<SpeedSync::Score> run_search (const Key& key, const SpeedScanParams& scan_params, const vector<double>& speeds);
vector<SpeedSync::Score> refine_search (const SpeedScanParams& scan_params, double speed);
};
vector<SpeedSync::Score>
SpeedSearch::run_search (const SpeedScanParams& scan_params, const vector<double>& speeds)
SpeedSearch::run_search (const Key& key, const SpeedScanParams& scan_params, const vector<double>& speeds)
{
/* speed is between 0.8 and 1.25, so we use a clip seconds factor of 1.3 to provide enough samples */
WavData in_clip = get_speed_clip (clip_location, in_data, scan_params.seconds * 1.3);
......@@ -497,7 +497,7 @@ SpeedSearch::run_search (const SpeedScanParams& scan_params, const vector<double
{
double c_speed = speed * pow (scan_params.step, c * (scan_params.n_steps * 2 + 1));
speed_sync.push_back (std::make_unique<SpeedSync> (in_clip, c_speed));
speed_sync.push_back (std::make_unique<SpeedSync> (key, in_clip, c_speed));
}
}
......@@ -578,9 +578,9 @@ select_n_best_scores (vector<SpeedSync::Score>& scores, size_t n)
}
static vector<double>
get_clip_locations (const WavData& in_data, int n)
get_clip_locations (const Key& key, const WavData& in_data, int n)
{
Random rng (0, Random::Stream::speed_clip);
Random rng (key, 0, Random::Stream::speed_clip);
/* to improve performance, we don't hash all samples but just a few */
const vector<float>& samples = in_data.samples();
......@@ -598,13 +598,13 @@ get_clip_locations (const WavData& in_data, int n)
}
static double
get_best_clip_location (const WavData& in_data, double seconds, int candidates)
get_best_clip_location (const Key& key, const WavData& in_data, double seconds, int candidates)
{
double clip_location = 0;
double best_energy = 0;
/* try a few clip locations, use the one with highest signal energy */
for (auto location : get_clip_locations (in_data, candidates))
for (auto location : get_clip_locations (key, in_data, candidates))
{
WavData wd = get_speed_clip (location, in_data, seconds);
......@@ -621,7 +621,7 @@ get_best_clip_location (const WavData& in_data, double seconds, int candidates)
}
double
detect_speed (const WavData& in_data, bool print_results)
detect_speed (const Key& key, const WavData& in_data, bool print_results)
{
/* typically even for high strength we need at least a few seconds of audio
* in in_data for successful speed detection, but our algorithm won't work at
......@@ -674,13 +674,13 @@ detect_speed (const WavData& in_data, bool print_results)
// SpeedSearch::debug_range (scan1);
const int clip_candidates = 5;
const double clip_location = get_best_clip_location (in_data, scan1.seconds, clip_candidates);
const double clip_location = get_best_clip_location (key, in_data, scan1.seconds, clip_candidates);
vector<SpeedSync::Score> scores;
SpeedSearch speed_search (in_data, clip_location);
/* initial search using grid */
scores = speed_search.run_search (scan1, { 1.0 });
scores = speed_search.run_search (key, scan1, { 1.0 });
/* improve N best matches */
select_n_best_scores (scores, n_best);
......@@ -689,14 +689,14 @@ detect_speed (const WavData& in_data, bool print_results)
for (auto score : scores)
speeds.push_back (score.speed);
scores = speed_search.run_search (scan2, speeds);
scores = speed_search.run_search (key, scan2, speeds);
/* improve or refine best match */
select_n_best_scores (scores, 1);
if (Params::detect_speed_patient)
{
// slower version: prepare magnitudes again, according to best speed
scores = speed_search.run_search (scan3, { scores[0].speed });
scores = speed_search.run_search (key, scan3, { scores[0].speed });
}
else
{
......
......@@ -19,7 +19,8 @@
#define AUDIOWMARK_WM_SPEED_HH
#include "wavdata.hh"
#include "random.hh"
double detect_speed (const WavData& in_data, bool print_results);
double detect_speed (const Key& key, const WavData& in_data, bool print_results);
#endif /* AUDIOWMARK_WM_SPEED_HH */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment