Update app.py
app.py CHANGED
@@ -12,12 +12,12 @@ import nltk
 nltk.download('punkt')

 # Load bi encoder
-top_k = 10
+# top_k = 10
 cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

 # Load your fine-tuned model and tokenizer
 model_name = "google/flan-t5-large"
-peft_name = "legacy107/flan-t5-large-ia3-newsqa"
+peft_name = "legacy107/flan-t5-large-ia3-newsqa-100"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
 model = T5ForConditionalGeneration.from_pretrained(model_name)
@@ -31,7 +31,7 @@ dataset = dataset.shuffle()
 dataset = dataset.select(range(10))

 # Context chunking
-def chunk_splitter(context, chunk_size=
+def chunk_splitter(context, chunk_size=100, overlap=0.10):
     overlap_size = chunk_size * overlap
     sentences = nltk.sent_tokenize(context)

@@ -75,7 +75,7 @@ def retrieve_context(query, contexts):
     hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)

     return " ".join(
-        [contexts[hit["corpus_id"]] for hit in hits
+        [contexts[hit["corpus_id"]] for hit in hits]
     ).replace("\n", " ")
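The renamed checkpoint, legacy107/flan-t5-large-ia3-newsqa-100, is an IA3 adapter, but the diff never shows where peft_name is consumed. A minimal sketch of how such an adapter is typically attached with Hugging Face's peft library, assuming app.py layers it onto the second T5 instance:

from transformers import AutoTokenizer, T5ForConditionalGeneration
from peft import PeftModel

model_name = "google/flan-t5-large"
peft_name = "legacy107/flan-t5-large-ia3-newsqa-100"  # adapter repo from the diff

tokenizer = AutoTokenizer.from_pretrained(model_name)
pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)  # untouched baseline
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Assumption: the IA3 adapter weights are loaded on top of the frozen base model.
model = PeftModel.from_pretrained(model, peft_name)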
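Only the signature and first two lines of chunk_splitter appear in the second hunk. A sketch of how the sentence-level sliding window could continue, assuming chunks are measured in whitespace-split words and roughly overlap_size words are carried between consecutive chunks (everything past the sent_tokenize call is an assumption, not code from app.py):

import nltk

nltk.download('punkt')

def chunk_splitter(context, chunk_size=100, overlap=0.10):
    overlap_size = chunk_size * overlap
    sentences = nltk.sent_tokenize(context)

    chunks = []
    current, length = [], 0
    fresh = False  # does `current` hold anything not yet emitted?
    for sentence in sentences:
        current.append(sentence)
        length += len(sentence.split())
        fresh = True
        if length >= chunk_size:
            chunks.append(" ".join(current))
            # Carry roughly overlap_size trailing words into the next chunk.
            carried, carried_len = [], 0
            for sent in reversed(current):
                carried.insert(0, sent)
                carried_len += len(sent.split())
                if carried_len >= overlap_size:
                    break
            current, length, fresh = carried, carried_len, False
    if current and fresh:
        chunks.append(" ".join(current))
    return chunks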
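The last hunk repairs a missing closing bracket in the list comprehension that rebuilds the reranked context string. Only the sort and the join are visible in the diff; the scoring step below is a sketch of how the ms-marco cross-encoder plausibly fills hits, using the sentence-transformers CrossEncoder.predict API:

from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def retrieve_context(query, contexts):
    # Score every (query, chunk) pair; higher means more relevant.
    scores = cross_encoder.predict([[query, chunk] for chunk in contexts])
    hits = [{"corpus_id": i, "cross-score": s} for i, s in enumerate(scores)]
    # The two lines below mirror app.py; everything above is an assumption.
    hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
    return " ".join(
        [contexts[hit["corpus_id"]] for hit in hits]
    ).replace("\n", " ")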