Henry Qu committed on
Commit
a79b023
·
1 Parent(s): bb448d0
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -4,13 +4,14 @@ from huggingface_hub import hf_hub_download
4
  from pathlib import Path
5
  from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
6
  import json
 
7
 
8
  model = GPT2LMHeadModel.from_pretrained('gpt2')
9
  tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
10
  tokenizer.pad_token = tokenizer.eos_token
11
 
12
- logits_dict = {}
13
-
14
  json_file = 'index.json'
15
  with open(json_file, 'r') as file:
16
  data = json.load(file)
@@ -19,7 +20,7 @@ for key, value in data.items():
19
  inputs = tokenizer(text_description, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
20
  outputs = model(**inputs, labels=inputs["input_ids"])
21
  logits = outputs.logits
22
- logits_dict[key] = logits
23
 
24
 
25
  def search_index(query):
@@ -28,7 +29,9 @@ def search_index(query):
28
 
29
  max_similarity = float('-inf')
30
  max_similarity_uuid = None
31
- for uuid, logits in logits_dict.items():
 
 
32
  similarity = (outputs.logits * logits).sum()
33
  if similarity > max_similarity:
34
  max_similarity = similarity
 
4
  from pathlib import Path
5
  from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
6
  import json
7
+ import torch
8
 
9
  model = GPT2LMHeadModel.from_pretrained('gpt2')
10
  tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
11
  tokenizer.pad_token = tokenizer.eos_token
12
 
13
+ temp_folder = 'temp'
14
+ os.mkdir(temp_folder, exist_ok=True)
15
  json_file = 'index.json'
16
  with open(json_file, 'r') as file:
17
  data = json.load(file)
 
20
  inputs = tokenizer(text_description, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
21
  outputs = model(**inputs, labels=inputs["input_ids"])
22
  logits = outputs.logits
23
+ torch.save(logits, os.path.join(temp_folder, f"{key}.pt"))
24
 
25
 
26
  def search_index(query):
 
29
 
30
  max_similarity = float('-inf')
31
  max_similarity_uuid = None
32
+ for file in os.listdir(temp_folder):
33
+ uuid = file.split('.')[0]
34
+ logits = torch.load(os.path.join(temp_folder, file))
35
  similarity = (outputs.logits * logits).sum()
36
  if similarity > max_similarity:
37
  max_similarity = similarity