ramy2018 commited on
Commit
4238475
·
verified ·
1 Parent(s): 5ab8d52

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +3 -4
rag_pipeline.py CHANGED
@@ -2,18 +2,17 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  from sentence_transformers import SentenceTransformer, models
3
  import numpy as np
4
  import torch
5
- import time
6
 
7
  class RAGPipeline:
8
  def __init__(self):
9
- print("[RAG] تحميل النماذج العربية...")
10
 
11
  word_embedding_model = models.Transformer('asafaya/bert-base-arabic')
12
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
13
  self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
14
 
15
- self.tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
16
- self.model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
17
 
18
  self.index = None
19
  self.chunks = []
 
2
  from sentence_transformers import SentenceTransformer, models
3
  import numpy as np
4
  import torch
 
5
 
6
  class RAGPipeline:
7
  def __init__(self):
8
+ print("[RAG] تحميل النماذج...")
9
 
10
  word_embedding_model = models.Transformer('asafaya/bert-base-arabic')
11
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
12
  self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
13
 
14
+ self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
15
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
16
 
17
  self.index = None
18
  self.chunks = []