|
#include "Prob.h" |
|
#include "Ngram.h" |
|
#include "Vocab.h" |
|
|
|
#include <sstream> |
|
#include <string> |
|
#include <iostream> |
|
#include <cstdio> |
|
#include <unistd.h> |
|
#include <sys/socket.h> |
|
#include <sys/types.h> |
|
#include <netinet/in.h> |
|
#include <netdb.h> |
|
#include <cstring> |
|
#include <map> |
|
|
|
struct Cache { |
|
map<int, Cache> tree; |
|
float prob; |
|
Cache() : prob(0) {} |
|
}; |
|
|
|
struct LMClient { |
|
Vocab* voc; |
|
int sock, port; |
|
char *s; |
|
struct hostent *hp; |
|
struct sockaddr_in server; |
|
char res[8]; |
|
|
|
LMClient(Vocab* v, const char* host) : voc(v), port(6666) { |
|
s = strchr(host, ':'); |
|
|
|
if (s != NULL) { |
|
*s = '\0'; |
|
s+=1; |
|
port = atoi(s); |
|
} |
|
|
|
sock = socket(AF_INET, SOCK_STREAM, 0); |
|
|
|
hp = gethostbyname(host); |
|
if (hp == NULL) { |
|
fprintf(stderr, "unknown host %s\n", host); |
|
exit(1); |
|
} |
|
|
|
memset(&server, '\0', sizeof(server)); |
|
memcpy((char *)&server.sin_addr, hp->h_addr, hp->h_length); |
|
server.sin_family = hp->h_addrtype; |
|
server.sin_port = htons(port); |
|
|
|
int errors = 0; |
|
while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { |
|
cerr << "Error: connect()\n"; |
|
sleep(1); |
|
errors++; |
|
if (errors > 5) exit(1); |
|
} |
|
std::cerr << "Connected to LM on " << host << " on port " << port << std::endl; |
|
} |
|
float wordProb(int word, int* context) { |
|
Cache* cur = &cache; |
|
int i = 0; |
|
while (context[i] > 0) { |
|
cur = &cur->tree[context[i++]]; |
|
} |
|
cur = &cur->tree[word]; |
|
if (cur->prob) { return cur->prob; } |
|
|
|
i = 0; |
|
ostringstream os; |
|
os << "prob " << voc->getWord((VocabIndex)word); |
|
while (context[i] > 0) { |
|
os << ' ' << voc->getWord((VocabIndex)context[i++]); |
|
} |
|
os << endl; |
|
string out = os.str(); |
|
write(sock, out.c_str(), out.size()); |
|
int r = read(sock, res, 6); |
|
int errors = 0; |
|
int cnt = 0; |
|
while (1) { |
|
if (r < 0) { |
|
errors++; sleep(1); |
|
cerr << "Error: read()\n"; |
|
if (errors > 5) exit(1); |
|
} else if (r==0 || res[cnt] == '\n') { break; } |
|
else { |
|
cnt += r; |
|
if (cnt==6) break; |
|
read(sock, &res[cnt], 6-cnt); |
|
} |
|
} |
|
cur->prob = *reinterpret_cast<float*>(res); |
|
return cur->prob; |
|
} |
|
void clear() { |
|
cache.tree.clear(); |
|
} |
|
Cache cache; |
|
}; |
|
|
|
|