Commit d0c64605 authored by Stefan Westerfeld's avatar Stefan Westerfeld

Introduce SpeedSearch class.

Signed-off-by: Stefan Westerfeld's avatarStefan Westerfeld <stefan@space.twc.de>
parent e847567f
......@@ -367,8 +367,36 @@ score_average_best (const vector<SpeedSync::Score>& scores)
return best_speed;
}
static vector<SpeedSync::Score>
run_search (ThreadPool& thread_pool, vector<std::unique_ptr<SpeedSync>>& speed_sync, const SpeedScanParams& scan_params, const WavData& in_data, double clip_location, const vector<double>& speeds)
class SpeedSearch
{
ThreadPool thread_pool;
vector<std::unique_ptr<SpeedSync>> speed_sync;
const WavData& in_data;
double clip_location;
SpeedSync *
find_closest_speed_sync (double speed)
{
auto it = std::min_element (speed_sync.begin(), speed_sync.end(), [&](auto& x, auto& y)
{
return fabs (x->center_speed() - speed) < fabs (y->center_speed() - speed);
});
return (*it).get();
}
public:
SpeedSearch (const WavData& in_data, double clip_location) :
in_data (in_data),
clip_location (clip_location)
{
}
vector<SpeedSync::Score> run_search (const SpeedScanParams& scan_params, const vector<double>& speeds);
vector<SpeedSync::Score> refine_search (const SpeedScanParams& scan_params, double speed);
};
vector<SpeedSync::Score>
SpeedSearch::run_search (const SpeedScanParams& scan_params, const vector<double>& speeds)
{
/* speed is between 0.8 and 1.25, so we use a clip seconds factor of 1.3 to provide enough samples */
WavData in_clip = get_speed_clip (clip_location, in_data, scan_params.seconds * 1.3);
......@@ -403,6 +431,24 @@ run_search (ThreadPool& thread_pool, vector<std::unique_ptr<SpeedSync>>& speed_s
return scores;
}
vector<SpeedSync::Score>
SpeedSearch::refine_search (const SpeedScanParams& scan_params, double speed)
{
SpeedSync *center_speed_sync = find_closest_speed_sync (speed);
auto t0 = get_time();
center_speed_sync->start_search_jobs (thread_pool, scan_params, speed);
thread_pool.wait_all();
auto t1 = get_time();
printf ("detect_speed %.3f\n", t1 - t0);
return center_speed_sync->get_scores();
}
static void
select_n_best_scores (vector<SpeedSync::Score>& scores, size_t n)
{
......@@ -438,16 +484,6 @@ select_n_best_scores (vector<SpeedSync::Score>& scores, size_t n)
scores = lmax_scores;
}
static SpeedSync *
find_closest_speed_sync (const vector<std::unique_ptr<SpeedSync>>& speed_sync, double speed)
{
auto it = std::min_element (speed_sync.begin(), speed_sync.end(), [&](auto& x, auto& y)
{
return fabs (x->center_speed() - speed) < fabs (y->center_speed() - speed);
});
return (*it).get();
}
static vector<double>
get_clip_locations (const WavData& in_data, int n)
{
......@@ -534,9 +570,8 @@ detect_speed (const WavData& in_data, bool print_results)
const int clip_candidates = 5;
const double clip_location = get_best_clip_location (in_data, scan1.seconds, clip_candidates);
ThreadPool thread_pool;
vector<std::unique_ptr<SpeedSync>> speed_sync;
vector<SpeedSync::Score> scores;
SpeedSearch search_context (in_data, clip_location);
/* search using grid */
{
......@@ -547,43 +582,32 @@ detect_speed (const WavData& in_data, bool print_results)
speeds.push_back (pow (scan1.step, c * (scan1.n_steps * 2 + 1)));
}
scores = run_search (thread_pool, speed_sync, scan1, in_data, clip_location, speeds);
scores = search_context.run_search (scan1, speeds);
}
if (Params::detect_speed_patient)
{
select_n_best_scores (scores, 1);
scores = run_search (thread_pool, speed_sync, scan3, in_data, clip_location, { scores[0].speed });
return score_average_best (scores);
scores = search_context.run_search (scan3, { scores[0].speed });
}
else
{
/* search 5 best matches */
select_n_best_scores (scores, 5);
/* search 5 best matches */
select_n_best_scores (scores, 5);
{
vector<double> speeds;
for (auto score : scores)
speeds.push_back (score.speed);
scores = run_search (thread_pool, speed_sync, scan2, in_data, clip_location, speeds);
}
select_n_best_scores (scores, 1);
/* refine best match */
SpeedSync *center_speed_sync = find_closest_speed_sync (speed_sync, scores[0].speed);
auto t4 = get_time();
center_speed_sync->start_search_jobs (thread_pool, scan3, scores[0].speed);
thread_pool.wait_all();
vector<double> speeds;
for (auto score : scores)
speeds.push_back (score.speed);
auto t5 = get_time();
scores = search_context.run_search (scan2, speeds);
printf ("detect_speed %.3f\n", t5 - t4);
/* refine best match */
select_n_best_scores (scores, 1);
return score_average_best (center_speed_sync->get_scores());
scores = search_context.refine_search (scan3, scores[0].speed);
}
return score_average_best (scores);
#if 0
if (print_results)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment