legacy107 committed on
Commit
8e1c673
·
1 Parent(s): c462daf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -12,12 +12,12 @@ import nltk
12
  nltk.download('punkt')
13
 
14
  # Load bi encoder
15
- top_k = 10
16
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
17
 
18
  # Load your fine-tuned model and tokenizer
19
  model_name = "google/flan-t5-large"
20
- peft_name = "legacy107/flan-t5-large-ia3-newsqa"
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
  pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
23
  model = T5ForConditionalGeneration.from_pretrained(model_name)
@@ -31,7 +31,7 @@ dataset = dataset.shuffle()
31
  dataset = dataset.select(range(10))
32
 
33
  # Context chunking
34
- def chunk_splitter(context, chunk_size=50, overlap=0.10):
35
  overlap_size = chunk_size * overlap
36
  sentences = nltk.sent_tokenize(context)
37
 
@@ -75,7 +75,7 @@ def retrieve_context(query, contexts):
75
  hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
76
 
77
  return " ".join(
78
- [contexts[hit["corpus_id"]] for hit in hits[0:top_k]]
79
  ).replace("\n", " ")
80
 
81
 
 
12
  nltk.download('punkt')
13
 
14
  # Load bi encoder
15
+ # top_k = 10
16
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
17
 
18
  # Load your fine-tuned model and tokenizer
19
  model_name = "google/flan-t5-large"
20
+ peft_name = "legacy107/flan-t5-large-ia3-newsqa-100"
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
  pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
23
  model = T5ForConditionalGeneration.from_pretrained(model_name)
 
31
  dataset = dataset.select(range(10))
32
 
33
  # Context chunking
34
+ def chunk_splitter(context, chunk_size=100, overlap=0.10):
35
  overlap_size = chunk_size * overlap
36
  sentences = nltk.sent_tokenize(context)
37
 
 
75
  hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
76
 
77
  return " ".join(
78
+ [contexts[hit["corpus_id"]] for hit in hits]
79
  ).replace("\n", " ")
80
 
81