Henry Qu committed
Commit 0a4a4f3
1 Parent(s): 2696b3d
Files changed (1)
  1. app.py +12 -9
app.py CHANGED
@@ -10,34 +10,37 @@ model = GPT2LMHeadModel.from_pretrained('gpt2')
 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
 tokenizer.pad_token = tokenizer.eos_token
 
-temp_folder = 'temp'
-os.makedirs(temp_folder, exist_ok=True)
+# temp_folder = 'temp'
+# os.makedirs(temp_folder, exist_ok=True)
+logit = {}
 json_file = 'index.json'
 with open(json_file, 'r') as file:
     data = json.load(file)
 for key, value in data.items():
     text_description = value['text_description']
-    inputs = tokenizer(text_description, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
+    inputs = tokenizer(text_description, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
     outputs = model(**inputs, labels=inputs["input_ids"])
     logits = outputs.logits
-    torch.save(logits, os.path.join(temp_folder, f"{key}.pt"))
+    logit[key] = logits
+    # torch.save(logits, os.path.join(temp_folder, f"{key}.pt"))
 
 
 def search_index(query):
-    inputs = tokenizer(query, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
+    inputs = tokenizer(query, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
     outputs = model(**inputs, labels=inputs["input_ids"])
 
     max_similarity = float('-inf')
     max_similarity_uuid = None
-    for file in os.listdir(temp_folder):
-        uuid = file.split('.')[0]
-        logits = torch.load(os.path.join(temp_folder, file))
+    # for file in os.listdir(temp_folder):
+    #     uuid = file.split('.')[0]
+    #     logits = torch.load(os.path.join(temp_folder, file))
+    for uuid, logits in logit.items():
         similarity = (outputs.logits * logits).sum()
         if similarity > max_similarity:
             max_similarity = similarity
             max_similarity_uuid = uuid
 
-    gr.logger.info(f"Query: {query}")
+    gr.Info(f"Max similarity: {max_similarity}")
     return max_similarity_uuid
 
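In summary, the commit replaces the on-disk temp/ logits cache with an in-memory dict keyed by uuid, shortens the padded length from 128 to 32 tokens, and swaps the gr.logger.info call for Gradio's gr.Info pop-up. Below is a minimal, self-contained sketch of the resulting pattern under stated assumptions: the toy data dict stands in for index.json, the embed helper is a hypothetical name introduced here, and the sketch drops the labels= argument since only the logits are used for scoring.

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

def embed(text):
    # Fixed-length padding keeps every logits tensor the same shape,
    # (1, 32, vocab_size), so the elementwise product below lines up.
    inputs = tokenizer(text, return_tensors="pt", padding="max_length",
                       max_length=32, truncation=True)
    with torch.no_grad():
        return model(**inputs).logits

# Toy stand-in for index.json, which the app reads as {uuid: {"text_description": ...}}.
data = {"uuid-1": {"text_description": "a red bicycle"},
        "uuid-2": {"text_description": "a bowl of soup"}}

# Build the in-memory index once at startup, as the committed loop does.
logit = {key: embed(value["text_description"]) for key, value in data.items()}

def search_index(query):
    query_logits = embed(query)
    # Highest unnormalized dot product over the cached logits wins,
    # matching the commit's scoring rule.
    return max(logit, key=lambda uuid: (query_logits * logit[uuid]).sum().item())

print(search_index("bike"))  # e.g. "uuid-1"

Keeping the tensors in a dict means nothing touches the filesystem between requests, at the cost of holding every (1, 32, vocab_size) logits tensor in RAM; the shorter max_length=32 cuts that footprint fourfold versus 128, which is one plausible motivation for the change.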