Henry Qu committed on
Commit
442df1d
·
1 Parent(s): 5bffb3b

modified: app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -3,14 +3,38 @@ import os
3
  from huggingface_hub import hf_hub_download
4
  from pathlib import Path
5
  from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 
6
 
7
- config_class, model_class, tokenizer_class = GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
8
- model = model_class.from_pretrained('gpt2')
9
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def search_index(query):
13
- return "example_uuid"
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def download_video(uuid):
 
3
  from huggingface_hub import hf_hub_download
4
  from pathlib import Path
5
  from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
6
+ import json
7
 
8
import torch

# Load GPT-2 once at import time; eval() disables dropout so the logits
# used as index vectors are deterministic across runs.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2's tokenizer has no padding token; without this assignment the
# tokenizer(..., padding=True) calls below raise ValueError.
tokenizer.pad_token = tokenizer.eos_token

# uuid -> logits tensor of shape (1, seq_len, vocab_size) for each
# indexed text description; consumed by search_index().
logits_dict = {}

json_file = 'index.json'
with open(json_file, 'r') as file:
    data = json.load(file)
    for item in data:
        uuid = item['uuid']
        text_description = item['text_description']
        inputs = tokenizer(text_description, return_tensors="pt", padding=True, truncation=True)
        # Inference only: no_grad avoids building an autograd graph for
        # every indexed item (which would hold large tensors alive).
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        logits_dict[uuid] = outputs.logits
24
 
25
def search_index(query):
    """Return the uuid of the indexed item most similar to *query*.

    Similarity is the dot product of mean-pooled GPT-2 logits. Pooling
    over the sequence dimension is required: raw logits tensors have
    shape (1, seq_len, vocab_size) and seq_len differs between the query
    and each indexed text, so the original elementwise product would
    raise a broadcast RuntimeError whenever lengths differ.

    Returns None when the index (logits_dict) is empty.
    """
    import torch  # local import: module header does not import torch directly

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    # Inference only — no autograd graph needed.
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
    # (1, seq_len, vocab_size) -> (1, vocab_size): fixed-size query vector.
    query_vec = outputs.logits.mean(dim=1)

    max_similarity = float('-inf')
    max_similarity_uuid = None
    for uuid, logits in logits_dict.items():
        # Pool the stored logits the same way so shapes always agree.
        similarity = (query_vec * logits.mean(dim=1)).sum()
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_uuid = uuid

    return max_similarity_uuid
38
 
39
 
40
  def download_video(uuid):