|
|
|
|
|
#ifndef DLIB_LIBSVM_iO_Hh_ |
|
#define DLIB_LIBSVM_iO_Hh_ |
|
|
|
#include "libsvm_io_abstract.h" |
|
|
|
#include <fstream> |
|
#include <string> |
|
#include <utility> |
|
#include "../algs.h" |
|
#include "../matrix.h" |
|
#include "../string.h" |
|
#include "../svm/sparse_vector.h" |
|
#include <vector> |
|
|
|
namespace dlib |
|
{ |
|
struct sample_data_io_error : public error |
|
{ |
|
sample_data_io_error(const std::string& message): error(message) {} |
|
}; |
|
|
|
|
|
|
|
template <typename sample_type, typename label_type, typename alloc1, typename alloc2> |
|
void load_libsvm_formatted_data ( |
|
const std::string& file_name, |
|
std::vector<sample_type, alloc1>& samples, |
|
std::vector<label_type, alloc2>& labels |
|
) |
|
{ |
|
using namespace std; |
|
typedef typename sample_type::value_type pair_type; |
|
typedef typename basic_type<typename pair_type::first_type>::type key_type; |
|
typedef typename pair_type::second_type value_type; |
|
|
|
|
|
COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); |
|
|
|
samples.clear(); |
|
labels.clear(); |
|
|
|
ifstream fin(file_name.c_str()); |
|
|
|
if (!fin) |
|
throw sample_data_io_error("Unable to open file " + file_name); |
|
|
|
string line; |
|
istringstream sin; |
|
key_type key; |
|
value_type value; |
|
label_type label; |
|
sample_type sample; |
|
long line_num = 0; |
|
while (fin.peek() != EOF) |
|
{ |
|
++line_num; |
|
getline(fin, line); |
|
|
|
string::size_type pos = line.find_first_not_of(" \t\r\n"); |
|
|
|
|
|
if (pos == string::npos || line[pos] == '#') |
|
continue; |
|
|
|
sin.clear(); |
|
sin.str(line); |
|
sample.clear(); |
|
|
|
sin >> label; |
|
|
|
if (!sin) |
|
throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name ); |
|
|
|
|
|
sin >> ws; |
|
|
|
while (sin.peek() != EOF && sin.peek() != '#') |
|
{ |
|
|
|
sin >> key >> ws; |
|
|
|
|
|
if (sin.get() != ':') |
|
throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name); |
|
|
|
sin >> value; |
|
|
|
if (sin && value != 0) |
|
{ |
|
sample.insert(sample.end(), make_pair(key, value)); |
|
} |
|
|
|
sin >> ws; |
|
} |
|
|
|
samples.push_back(sample); |
|
labels.push_back(label); |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
template <typename sample_type, typename alloc> |
|
typename enable_if<is_const_type<typename sample_type::value_type::first_type> >::type |
|
fix_nonzero_indexing ( |
|
std::vector<sample_type,alloc>& samples |
|
) |
|
{ |
|
typedef typename sample_type::value_type pair_type; |
|
typedef typename basic_type<typename pair_type::first_type>::type key_type; |
|
|
|
if (samples.size() == 0) |
|
return; |
|
|
|
|
|
key_type min_idx = samples[0].begin()->first; |
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
min_idx = std::min(min_idx, samples[i].begin()->first); |
|
|
|
|
|
if (min_idx != 0) |
|
{ |
|
sample_type temp; |
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
{ |
|
|
|
temp.clear(); |
|
typename sample_type::iterator j; |
|
for (j = samples[i].begin(); j != samples[i].end(); ++j) |
|
{ |
|
temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second)); |
|
} |
|
|
|
|
|
samples[i].swap(temp); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
template <typename sample_type, typename alloc> |
|
typename disable_if<is_const_type<typename sample_type::value_type::first_type> >::type |
|
fix_nonzero_indexing ( |
|
std::vector<sample_type,alloc>& samples |
|
) |
|
{ |
|
typedef typename sample_type::value_type pair_type; |
|
typedef typename basic_type<typename pair_type::first_type>::type key_type; |
|
|
|
if (samples.size() == 0) |
|
return; |
|
|
|
|
|
key_type min_idx = samples[0].begin()->first; |
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
min_idx = std::min(min_idx, samples[i].begin()->first); |
|
|
|
|
|
if (min_idx != 0) |
|
{ |
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
{ |
|
typename sample_type::iterator j; |
|
for (j = samples[i].begin(); j != samples[i].end(); ++j) |
|
{ |
|
j->first -= min_idx; |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
template <typename sample_type, typename label_type, typename alloc1, typename alloc2> |
|
typename disable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( |
|
const std::string& file_name, |
|
const std::vector<sample_type, alloc1>& samples, |
|
const std::vector<label_type, alloc2>& labels |
|
) |
|
{ |
|
typedef typename sample_type::value_type pair_type; |
|
typedef typename basic_type<typename pair_type::first_type>::type key_type; |
|
|
|
|
|
COMPILE_TIME_ASSERT(is_unsigned_type<key_type>::value); |
|
|
|
|
|
DLIB_ASSERT(samples.size() == labels.size(), |
|
"\t void save_libsvm_formatted_data()" |
|
<< "\n\t You have to have labels for each sample and vice versa" |
|
<< "\n\t samples.size(): " << samples.size() |
|
<< "\n\t labels.size(): " << labels.size() |
|
); |
|
|
|
|
|
using namespace std; |
|
ofstream fout(file_name.c_str()); |
|
fout.precision(14); |
|
|
|
if (!fout) |
|
throw sample_data_io_error("Unable to open file " + file_name); |
|
|
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
{ |
|
fout << labels[i]; |
|
|
|
for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j) |
|
{ |
|
if (j->second != 0) |
|
fout << " " << j->first << ":" << j->second; |
|
} |
|
fout << "\n"; |
|
|
|
if (!fout) |
|
throw sample_data_io_error("Error while writing to file " + file_name); |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
template <typename sample_type, typename label_type, typename alloc1, typename alloc2> |
|
typename enable_if<is_matrix<sample_type>,void>::type save_libsvm_formatted_data ( |
|
const std::string& file_name, |
|
const std::vector<sample_type, alloc1>& samples, |
|
const std::vector<label_type, alloc2>& labels |
|
) |
|
{ |
|
|
|
DLIB_ASSERT(samples.size() == labels.size(), |
|
"\t void save_libsvm_formatted_data()" |
|
<< "\n\t You have to have labels for each sample and vice versa" |
|
<< "\n\t samples.size(): " << samples.size() |
|
<< "\n\t labels.size(): " << labels.size() |
|
); |
|
|
|
using namespace std; |
|
ofstream fout(file_name.c_str()); |
|
fout.precision(14); |
|
|
|
if (!fout) |
|
throw sample_data_io_error("Unable to open file " + file_name); |
|
|
|
for (unsigned long i = 0; i < samples.size(); ++i) |
|
{ |
|
fout << labels[i]; |
|
|
|
for (long j = 0; j < samples[i].size(); ++j) |
|
{ |
|
if (samples[i](j) != 0) |
|
fout << " " << j << ":" << samples[i](j); |
|
} |
|
fout << "\n"; |
|
|
|
if (!fout) |
|
throw sample_data_io_error("Error while writing to file " + file_name); |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
#endif |
|
|
|
|