Spaces:
Sleeping
Sleeping
AmmarFahmy
commited on
Commit
·
de81719
1
Parent(s):
d54ea9f
from local
Browse files- model/models--sentence-transformers--all-mpnet-base-v2/.no_exist/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/added_tokens.json +0 -0
- model/models--sentence-transformers--all-mpnet-base-v2/refs/main +1 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/1_Pooling/config.json +7 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/README.md +177 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/config.json +23 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/config_sentence_transformers.json +7 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/model.safetensors +3 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/modules.json +20 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/sentence_bert_config.json +4 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/special_tokens_map.json +1 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/tokenizer.json +0 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/tokenizer_config.json +1 -0
- model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/vocab.txt +0 -0
- model_download.py +17 -0
- source.txt +1 -0
- streamlit_app.py +1012 -0
model/models--sentence-transformers--all-mpnet-base-v2/.no_exist/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/added_tokens.json
ADDED
File without changes
|
model/models--sentence-transformers--all-mpnet-base-v2/refs/main
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
84f2bcc00d77236f9e89c8a360a00fb1139bf47d
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/1_Pooling/config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 768,
|
3 |
+
"pooling_mode_cls_token": false,
|
4 |
+
"pooling_mode_mean_tokens": true,
|
5 |
+
"pooling_mode_max_tokens": false,
|
6 |
+
"pooling_mode_mean_sqrt_len_tokens": false
|
7 |
+
}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/README.md
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: en
|
3 |
+
license: apache-2.0
|
4 |
+
library_name: sentence-transformers
|
5 |
+
tags:
|
6 |
+
- sentence-transformers
|
7 |
+
- feature-extraction
|
8 |
+
- sentence-similarity
|
9 |
+
- transformers
|
10 |
+
datasets:
|
11 |
+
- s2orc
|
12 |
+
- flax-sentence-embeddings/stackexchange_xml
|
13 |
+
- ms_marco
|
14 |
+
- gooaq
|
15 |
+
- yahoo_answers_topics
|
16 |
+
- code_search_net
|
17 |
+
- search_qa
|
18 |
+
- eli5
|
19 |
+
- snli
|
20 |
+
- multi_nli
|
21 |
+
- wikihow
|
22 |
+
- natural_questions
|
23 |
+
- trivia_qa
|
24 |
+
- embedding-data/sentence-compression
|
25 |
+
- embedding-data/flickr30k-captions
|
26 |
+
- embedding-data/altlex
|
27 |
+
- embedding-data/simple-wiki
|
28 |
+
- embedding-data/QQP
|
29 |
+
- embedding-data/SPECTER
|
30 |
+
- embedding-data/PAQ_pairs
|
31 |
+
- embedding-data/WikiAnswers
|
32 |
+
pipeline_tag: sentence-similarity
|
33 |
+
---
|
34 |
+
|
35 |
+
|
36 |
+
# all-mpnet-base-v2
|
37 |
+
This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.
|
38 |
+
|
39 |
+
## Usage (Sentence-Transformers)
|
40 |
+
Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
|
41 |
+
|
42 |
+
```
|
43 |
+
pip install -U sentence-transformers
|
44 |
+
```
|
45 |
+
|
46 |
+
Then you can use the model like this:
|
47 |
+
```python
|
48 |
+
from sentence_transformers import SentenceTransformer
|
49 |
+
sentences = ["This is an example sentence", "Each sentence is converted"]
|
50 |
+
|
51 |
+
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
52 |
+
embeddings = model.encode(sentences)
|
53 |
+
print(embeddings)
|
54 |
+
```
|
55 |
+
|
56 |
+
## Usage (HuggingFace Transformers)
|
57 |
+
Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
|
58 |
+
|
59 |
+
```python
|
60 |
+
from transformers import AutoTokenizer, AutoModel
|
61 |
+
import torch
|
62 |
+
import torch.nn.functional as F
|
63 |
+
|
64 |
+
#Mean Pooling - Take attention mask into account for correct averaging
|
65 |
+
def mean_pooling(model_output, attention_mask):
|
66 |
+
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
67 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
68 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
69 |
+
|
70 |
+
|
71 |
+
# Sentences we want sentence embeddings for
|
72 |
+
sentences = ['This is an example sentence', 'Each sentence is converted']
|
73 |
+
|
74 |
+
# Load model from HuggingFace Hub
|
75 |
+
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
|
76 |
+
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
|
77 |
+
|
78 |
+
# Tokenize sentences
|
79 |
+
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
80 |
+
|
81 |
+
# Compute token embeddings
|
82 |
+
with torch.no_grad():
|
83 |
+
model_output = model(**encoded_input)
|
84 |
+
|
85 |
+
# Perform pooling
|
86 |
+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
87 |
+
|
88 |
+
# Normalize embeddings
|
89 |
+
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
90 |
+
|
91 |
+
print("Sentence embeddings:")
|
92 |
+
print(sentence_embeddings)
|
93 |
+
```
|
94 |
+
|
95 |
+
## Evaluation Results
|
96 |
+
|
97 |
+
For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/all-mpnet-base-v2)
|
98 |
+
|
99 |
+
------
|
100 |
+
|
101 |
+
## Background
|
102 |
+
|
103 |
+
The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
|
104 |
+
contrastive learning objective. We used the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model and fine-tuned in on a
|
105 |
+
1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
|
106 |
+
|
107 |
+
We developped this model during the
|
108 |
+
[Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
|
109 |
+
organized by Hugging Face. We developped this model as part of the project:
|
110 |
+
[Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Googles Flax, JAX, and Cloud team member about efficient deep learning frameworks.
|
111 |
+
|
112 |
+
## Intended uses
|
113 |
+
|
114 |
+
Our model is intented to be used as a sentence and short paragraph encoder. Given an input text, it ouptuts a vector which captures
|
115 |
+
the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
|
116 |
+
|
117 |
+
By default, input text longer than 384 word pieces is truncated.
|
118 |
+
|
119 |
+
|
120 |
+
## Training procedure
|
121 |
+
|
122 |
+
### Pre-training
|
123 |
+
|
124 |
+
We use the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model. Please refer to the model card for more detailed information about the pre-training procedure.
|
125 |
+
|
126 |
+
### Fine-tuning
|
127 |
+
|
128 |
+
We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
|
129 |
+
We then apply the cross entropy loss by comparing with true pairs.
|
130 |
+
|
131 |
+
#### Hyper parameters
|
132 |
+
|
133 |
+
We trained ou model on a TPU v3-8. We train the model during 100k steps using a batch size of 1024 (128 per TPU core).
|
134 |
+
We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
|
135 |
+
a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
|
136 |
+
|
137 |
+
#### Training data
|
138 |
+
|
139 |
+
We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
|
140 |
+
We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
|
141 |
+
|
142 |
+
|
143 |
+
| Dataset | Paper | Number of training tuples |
|
144 |
+
|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
|
145 |
+
| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
|
146 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
|
147 |
+
| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
|
148 |
+
| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
|
149 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
|
150 |
+
| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
|
151 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
|
152 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
|
153 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
|
154 |
+
| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
|
155 |
+
| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
|
156 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
|
157 |
+
| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
|
158 |
+
| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
|
159 |
+
| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
|
160 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
|
161 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
|
162 |
+
| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
|
163 |
+
| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
|
164 |
+
| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
|
165 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
|
166 |
+
| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
|
167 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
|
168 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
|
169 |
+
| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
|
170 |
+
| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
|
171 |
+
| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
|
172 |
+
| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
|
173 |
+
| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
|
174 |
+
| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
|
175 |
+
| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
|
176 |
+
| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
|
177 |
+
| **Total** | | **1,170,060,424** |
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/mpnet-base",
|
3 |
+
"architectures": [
|
4 |
+
"MPNetForMaskedLM"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"eos_token_id": 2,
|
9 |
+
"hidden_act": "gelu",
|
10 |
+
"hidden_dropout_prob": 0.1,
|
11 |
+
"hidden_size": 768,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 3072,
|
14 |
+
"layer_norm_eps": 1e-05,
|
15 |
+
"max_position_embeddings": 514,
|
16 |
+
"model_type": "mpnet",
|
17 |
+
"num_attention_heads": 12,
|
18 |
+
"num_hidden_layers": 12,
|
19 |
+
"pad_token_id": 1,
|
20 |
+
"relative_attention_num_buckets": 32,
|
21 |
+
"transformers_version": "4.8.2",
|
22 |
+
"vocab_size": 30527
|
23 |
+
}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/config_sentence_transformers.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "2.0.0",
|
4 |
+
"transformers": "4.6.1",
|
5 |
+
"pytorch": "1.8.1"
|
6 |
+
}
|
7 |
+
}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:78c0197b6159d92658e319bc1d72e4c73a9a03dd03815e70e555c5ef05615658
|
3 |
+
size 437971872
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/modules.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers.models.Transformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"idx": 2,
|
16 |
+
"name": "2",
|
17 |
+
"path": "2_Normalize",
|
18 |
+
"type": "sentence_transformers.models.Normalize"
|
19 |
+
}
|
20 |
+
]
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 384,
|
3 |
+
"do_lower_case": false
|
4 |
+
}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": true, "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "[UNK]", "pad_token": "<pad>", "mask_token": "<mask>", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "microsoft/mpnet-base", "tokenizer_class": "MPNetTokenizer"}
|
model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model_download.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from transformers import AutoModel, AutoTokenizer
|
2 |
+
# from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
+
|
4 |
+
# local_model_path = "models"
|
5 |
+
|
6 |
+
# # Download the model and tokenizer
|
7 |
+
# model_name = "sentence-transformers/all-mpnet-base-v2"
|
8 |
+
# AutoModel.from_pretrained(model_name, cache_dir=local_model_path)
|
9 |
+
# AutoTokenizer.from_pretrained(model_name, cache_dir=local_model_path)
|
10 |
+
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
|
13 |
+
local_model_path = "model"
|
14 |
+
|
15 |
+
# Download the model
|
16 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
17 |
+
model = SentenceTransformer(model_name, cache_folder=local_model_path)
|
source.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
https://github.com/IntelligenzaArtificiale/Free-personal-AI-Assistant-with-plugin/
|
streamlit_app.py
ADDED
@@ -0,0 +1,1012 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
import string
|
5 |
+
from zipfile import ZipFile
|
6 |
+
import streamlit as st
|
7 |
+
from streamlit_extras.colored_header import colored_header
|
8 |
+
from streamlit_extras.add_vertical_space import add_vertical_space
|
9 |
+
from hugchat import hugchat
|
10 |
+
from hugchat.login import Login
|
11 |
+
import pandas as pd
|
12 |
+
import asyncio
|
13 |
+
loop = asyncio.new_event_loop()
|
14 |
+
asyncio.set_event_loop(loop)
|
15 |
+
import sketch
|
16 |
+
from langchain.text_splitter import CharacterTextSplitter
|
17 |
+
from promptTemplate import prompt4conversation, prompt4Data, prompt4Code, prompt4Context, prompt4Audio, prompt4YT
|
18 |
+
from promptTemplate import prompt4conversationInternet
|
19 |
+
# FOR DEVELOPMENT NEW PLUGIN
|
20 |
+
# from promptTemplate import yourPLUGIN
|
21 |
+
from exportchat import export_chat
|
22 |
+
|
23 |
+
from langchain_community.vectorstores import Chroma
|
24 |
+
from langchain_community.embeddings import HuggingFaceHubEmbeddings
|
25 |
+
|
26 |
+
from langchain.chains import RetrievalQA
|
27 |
+
from HuggingChatAPI import HuggingChat
|
28 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
29 |
+
import requests
|
30 |
+
from bs4 import BeautifulSoup
|
31 |
+
import speech_recognition as sr
|
32 |
+
import pdfplumber
|
33 |
+
import docx2txt
|
34 |
+
from duckduckgo_search import DDGS
|
35 |
+
from itertools import islice
|
36 |
+
from os import path
|
37 |
+
from pydub import AudioSegment
|
38 |
+
import os
|
39 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
40 |
+
|
41 |
+
hf = None
|
42 |
+
# repo_id = "sentence-transformers/all-mpnet-base-v2"
|
43 |
+
# repo_id = "nomic-ai/nomic-embed-text-v1.5"
|
44 |
+
|
45 |
+
|
46 |
+
# if 'hf_token' in st.session_state:
|
47 |
+
# if 'hf' not in st.session_state:
|
48 |
+
# hf = HuggingFaceHubEmbeddings(
|
49 |
+
# repo_id=repo_id,
|
50 |
+
# task="feature-extraction",
|
51 |
+
# huggingfacehub_api_token=st.session_state['hf_token'],
|
52 |
+
# ) # type: ignore
|
53 |
+
# st.session_state['hf'] = hf
|
54 |
+
|
55 |
+
hf = None
|
56 |
+
local_model_path = "model"
|
57 |
+
model_name = "sentence-transformers/all-mpnet-base-v2"
|
58 |
+
hf = None
|
59 |
+
if 'hf' not in st.session_state:
|
60 |
+
hf = HuggingFaceEmbeddings(
|
61 |
+
model_name=model_name,
|
62 |
+
cache_folder=local_model_path,
|
63 |
+
)
|
64 |
+
st.session_state['hf'] = hf
|
65 |
+
|
66 |
+
st.set_page_config(
|
67 |
+
page_title="Talk with evrythings💬", page_icon="🤗", layout="wide", initial_sidebar_state="expanded"
|
68 |
+
)
|
69 |
+
|
70 |
+
st.markdown('<style>.css-w770g5{\
|
71 |
+
width: 100%;}\
|
72 |
+
.css-b3z5c9{ \
|
73 |
+
width: 100%;}\
|
74 |
+
.stButton>button{\
|
75 |
+
width: 100%;}\
|
76 |
+
.stDownloadButton>button{\
|
77 |
+
width: 100%;}\
|
78 |
+
</style>', unsafe_allow_html=True)
|
79 |
+
|
80 |
+
|
81 |
+
# Sidebar contents for logIN, choose plugin, and export chat
|
82 |
+
with st.sidebar:
|
83 |
+
st.title('🤗💬 PersonalChat App')
|
84 |
+
|
85 |
+
if 'hf_email' not in st.session_state or 'hf_pass' not in st.session_state:
|
86 |
+
with st.expander("ℹ️ Login in Hugging Face", expanded=True):
|
87 |
+
st.write("⚠️ You need to login in Hugging Face to use this app. You can register [here](https://huggingface.co/join).")
|
88 |
+
st.header('Hugging Face Login')
|
89 |
+
hf_email = st.text_input('Enter E-mail:')
|
90 |
+
hf_pass = st.text_input('Enter password:', type='password')
|
91 |
+
hf_token = st.text_input('Enter API Token:', type='password')
|
92 |
+
if st.button('Login 🚀') and hf_email and hf_pass and hf_token:
|
93 |
+
with st.spinner('🚀 Logging in...'):
|
94 |
+
st.session_state['hf_email'] = hf_email
|
95 |
+
st.session_state['hf_pass'] = hf_pass
|
96 |
+
st.session_state['hf_token'] = hf_token
|
97 |
+
|
98 |
+
try:
|
99 |
+
|
100 |
+
sign = Login(st.session_state['hf_email'], st.session_state['hf_pass'])
|
101 |
+
cookies = sign.login()
|
102 |
+
chatbot = hugchat.ChatBot(cookies=cookies.get_dict())
|
103 |
+
a_models = chatbot.get_available_llm_models()
|
104 |
+
# print(a_models)
|
105 |
+
except Exception as e:
|
106 |
+
st.error(e)
|
107 |
+
st.info("⚠️ Please check your credentials and try again.")
|
108 |
+
# st.error("⚠️ dont abuse the API")
|
109 |
+
st.warning("⚠️ If you don't have an account, you can register [here](https://huggingface.co/join).")
|
110 |
+
from time import sleep
|
111 |
+
sleep(3)
|
112 |
+
del st.session_state['hf_email']
|
113 |
+
del st.session_state['hf_pass']
|
114 |
+
del st.session_state['hf_token']
|
115 |
+
st.experimental_rerun()
|
116 |
+
|
117 |
+
st.session_state['chatbot'] = chatbot
|
118 |
+
|
119 |
+
id = st.session_state['chatbot'].new_conversation()
|
120 |
+
st.session_state['chatbot'].change_conversation(id)
|
121 |
+
|
122 |
+
st.session_state['conversation'] = id
|
123 |
+
# Generate empty lists for generated and past.
|
124 |
+
## generated stores AI generated responses
|
125 |
+
if 'generated' not in st.session_state:
|
126 |
+
st.session_state['generated'] = ["How may I help you ? "]
|
127 |
+
## past stores User's questions
|
128 |
+
if 'past' not in st.session_state:
|
129 |
+
st.session_state['past'] = ['Hi!']
|
130 |
+
|
131 |
+
st.session_state['LLM'] = HuggingChat(email=st.session_state['hf_email'], psw=st.session_state['hf_pass'], model=0)
|
132 |
+
|
133 |
+
st.experimental_rerun()
|
134 |
+
|
135 |
+
|
136 |
+
else:
|
137 |
+
with st.expander("ℹ️ Advanced Settings"):
|
138 |
+
#temperature: Optional[float]. Default is 0.5
|
139 |
+
#top_p: Optional[float]. Default is 0.95
|
140 |
+
#repetition_penalty: Optional[float]. Default is 1.2
|
141 |
+
#top_k: Optional[int]. Default is 50
|
142 |
+
#max_new_tokens: Optional[int]. Default is 1024
|
143 |
+
|
144 |
+
temperature = st.slider('🌡 Temperature', min_value=0.1, max_value=1.0, value=0.5, step=0.01)
|
145 |
+
top_p = st.slider('💡 Top P', min_value=0.1, max_value=1.0, value=0.95, step=0.01)
|
146 |
+
repetition_penalty = st.slider('🖌 Repetition Penalty', min_value=1.0, max_value=2.0, value=1.2, step=0.01)
|
147 |
+
top_k = st.slider('❄️ Top K', min_value=1, max_value=100, value=50, step=1)
|
148 |
+
max_new_tokens = st.slider('📝 Max New Tokens', min_value=500, max_value=4000, value=4000, step=50)
|
149 |
+
|
150 |
+
|
151 |
+
# FOR DEVELOPMENT NEW PLUGIN YOU MUST ADD IT HERE INTO THE LIST
|
152 |
+
# YOU NEED ADD THE NAME AT 144 LINE
|
153 |
+
|
154 |
+
#plugins for conversation
|
155 |
+
plugins = ["🛑 No PLUGIN","🌐 Web Search", "🔗 Talk with Website" , "📋 Talk with your DATA", "📝 Talk with your DOCUMENTS", "🎧 Talk with your AUDIO", "🎥 Talk with YT video", "🧠 GOD MODE" ,"💾 Upload saved VectorStore"]
|
156 |
+
if 'plugin' not in st.session_state:
|
157 |
+
st.session_state['plugin'] = st.selectbox('🔌 Plugins', plugins, index=0)
|
158 |
+
else:
|
159 |
+
if st.session_state['plugin'] == "🛑 No PLUGIN":
|
160 |
+
st.session_state['plugin'] = st.selectbox('🔌 Plugins', plugins, index=plugins.index(st.session_state['plugin']))
|
161 |
+
|
162 |
+
|
163 |
+
# FOR DEVELOPMENT NEW PLUGIN FOLLOW THIS TEMPLATE
|
164 |
+
# PLUGIN TEMPLATE
|
165 |
+
# if st.session_state['plugin'] == "🔌 PLUGIN NAME" and 'PLUGIN NAME' not in st.session_state:
|
166 |
+
# # PLUGIN SETTINGS
|
167 |
+
# with st.expander("🔌 PLUGIN NAME Settings", expanded=True):
|
168 |
+
# if 'PLUGIN NAME' not in st.session_state or st.session_state['PLUGIN NAME'] == False:
|
169 |
+
# # PLUGIN CODE
|
170 |
+
# st.session_state['PLUGIN NAME'] = True
|
171 |
+
# elif st.session_state['PLUGIN NAME'] == True:
|
172 |
+
# # PLUGIN CODE
|
173 |
+
# if st.button('🔌 Disable PLUGIN NAME'):
|
174 |
+
# st.session_state['plugin'] = "🛑 No PLUGIN"
|
175 |
+
# st.session_state['PLUGIN NAME'] = False
|
176 |
+
# del ALL SESSION STATE VARIABLES RELATED TO PLUGIN
|
177 |
+
# st.experimental_rerun()
|
178 |
+
# # PLUGIN UPLOADER
|
179 |
+
# if st.session_state['PLUGIN NAME'] == True:
|
180 |
+
# with st.expander("🔌 PLUGIN NAME Uploader", expanded=True):
|
181 |
+
# # PLUGIN UPLOADER CODE
|
182 |
+
# load file
|
183 |
+
# if load file and st.button('🔌 Upload PLUGIN NAME'):
|
184 |
+
# qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
185 |
+
# st.session_state['PLUGIN DB'] = qa
|
186 |
+
# st.experimental_rerun()
|
187 |
+
#
|
188 |
+
|
189 |
+
|
190 |
+
|
191 |
+
# WEB SEARCH PLUGIN
|
192 |
+
if st.session_state['plugin'] == "🌐 Web Search" and 'web_search' not in st.session_state:
|
193 |
+
# web search settings
|
194 |
+
with st.expander("🌐 Web Search Settings", expanded=True):
|
195 |
+
if 'web_search' not in st.session_state or st.session_state['web_search'] == False:
|
196 |
+
reg = ['us-en', 'uk-en', 'it-it']
|
197 |
+
sf = ['on', 'moderate', 'off']
|
198 |
+
tl = ['d', 'w', 'm', 'y']
|
199 |
+
if 'region' not in st.session_state:
|
200 |
+
st.session_state['region'] = st.selectbox('🗺 Region', reg, index=1)
|
201 |
+
else:
|
202 |
+
st.session_state['region'] = st.selectbox('🗺 Region', reg, index=reg.index(st.session_state['region']))
|
203 |
+
if 'safesearch' not in st.session_state:
|
204 |
+
st.session_state['safesearch'] = st.selectbox('🚨 Safe Search', sf, index=1)
|
205 |
+
else:
|
206 |
+
st.session_state['safesearch'] = st.selectbox('🚨 Safe Search', sf, index=sf.index(st.session_state['safesearch']))
|
207 |
+
if 'timelimit' not in st.session_state:
|
208 |
+
st.session_state['timelimit'] = st.selectbox('📅 Time Limit', tl, index=1)
|
209 |
+
else:
|
210 |
+
st.session_state['timelimit'] = st.selectbox('📅 Time Limit', tl, index=tl.index(st.session_state['timelimit']))
|
211 |
+
if 'max_results' not in st.session_state:
|
212 |
+
st.session_state['max_results'] = st.slider('📊 Max Results', min_value=1, max_value=5, value=2, step=1)
|
213 |
+
else:
|
214 |
+
st.session_state['max_results'] = st.slider('📊 Max Results', min_value=1, max_value=5, value=st.session_state['max_results'], step=1)
|
215 |
+
if st.button('🌐 Save change'):
|
216 |
+
st.session_state['web_search'] = "True"
|
217 |
+
st.experimental_rerun()
|
218 |
+
|
219 |
+
elif st.session_state['plugin'] == "🌐 Web Search" and st.session_state['web_search'] == 'True':
|
220 |
+
with st.expander("🌐 Web Search Settings", expanded=True):
|
221 |
+
st.write('🚀 Web Search is enabled')
|
222 |
+
st.write('🗺 Region: ', st.session_state['region'])
|
223 |
+
st.write('🚨 Safe Search: ', st.session_state['safesearch'])
|
224 |
+
st.write('📅 Time Limit: ', st.session_state['timelimit'])
|
225 |
+
if st.button('🌐🛑 Disable Web Search'):
|
226 |
+
del st.session_state['web_search']
|
227 |
+
del st.session_state['region']
|
228 |
+
del st.session_state['safesearch']
|
229 |
+
del st.session_state['timelimit']
|
230 |
+
del st.session_state['max_results']
|
231 |
+
del st.session_state['plugin']
|
232 |
+
st.experimental_rerun()
|
233 |
+
|
234 |
+
# GOD MODE PLUGIN
|
235 |
+
if st.session_state['plugin'] == "🧠 GOD MODE" and 'god_mode' not in st.session_state:
|
236 |
+
with st.expander("🧠 GOD MODE Settings", expanded=True):
|
237 |
+
if 'god_mode' not in st.session_state or st.session_state['god_mode'] == False:
|
238 |
+
topic = st.text_input('🔎 Topic', "Artificial Intelligence in Finance")
|
239 |
+
web_result = st.checkbox('🌐 Web Search', value=True, disabled=True)
|
240 |
+
yt_result = st.checkbox('🎥 YT Search', value=True, disabled=True)
|
241 |
+
website_result = st.checkbox('🔗 Website Search', value=True, disabled=True)
|
242 |
+
deep_of_search = st.slider('📊 Deep of Search', min_value=1, max_value=5, value=2, step=1)
|
243 |
+
if st.button('🧠✅ Give knowledge to the model'):
|
244 |
+
full_text = []
|
245 |
+
links = []
|
246 |
+
news = []
|
247 |
+
yt_ids = []
|
248 |
+
source = []
|
249 |
+
if web_result == True:
|
250 |
+
internet_result = ""
|
251 |
+
internet_answer = ""
|
252 |
+
with DDGS() as ddgs:
|
253 |
+
with st.spinner('🌐 Searching on the web...'):
|
254 |
+
ddgs_gen = ddgs.text(topic, region="us-en")
|
255 |
+
for r in islice(ddgs_gen, deep_of_search):
|
256 |
+
l = r['href']
|
257 |
+
source.append(l)
|
258 |
+
links.append(l)
|
259 |
+
internet_result += str(r) + "\n\n"
|
260 |
+
|
261 |
+
fast_answer = ddgs.news(topic)
|
262 |
+
for r in islice(fast_answer, deep_of_search):
|
263 |
+
internet_answer += str(r) + "\n\n"
|
264 |
+
l = r['url']
|
265 |
+
source.append(l)
|
266 |
+
news.append(r)
|
267 |
+
|
268 |
+
|
269 |
+
full_text.append(internet_result)
|
270 |
+
full_text.append(internet_answer)
|
271 |
+
|
272 |
+
if yt_result == True:
|
273 |
+
with st.spinner('🎥 Searching on YT...'):
|
274 |
+
from youtubesearchpython import VideosSearch
|
275 |
+
videosSearch = VideosSearch(topic, limit = deep_of_search)
|
276 |
+
yt_result = videosSearch.result()
|
277 |
+
for i in yt_result['result']: # type: ignore
|
278 |
+
duration = i['duration'] # type: ignore
|
279 |
+
duration = duration.split(':')
|
280 |
+
if len(duration) == 3:
|
281 |
+
#skip videos longer than 1 hour
|
282 |
+
if int(duration[0]) > 1:
|
283 |
+
continue
|
284 |
+
if len(duration) == 2:
|
285 |
+
#skip videos longer than 30 minutes
|
286 |
+
if int(duration[0]) > 30:
|
287 |
+
continue
|
288 |
+
yt_ids.append(i['id']) # type: ignore
|
289 |
+
source.append("https://www.youtube.com/watch?v="+i['id']) # type: ignore
|
290 |
+
full_text.append(i['title']) # type: ignore
|
291 |
+
|
292 |
+
|
293 |
+
if website_result == True:
|
294 |
+
for l in links:
|
295 |
+
try:
|
296 |
+
with st.spinner(f'👨💻 Scraping website : {l}'):
|
297 |
+
r = requests.get(l)
|
298 |
+
soup = BeautifulSoup(r.content, 'html.parser')
|
299 |
+
full_text.append(soup.get_text()+"\n\n")
|
300 |
+
except:
|
301 |
+
pass
|
302 |
+
|
303 |
+
for id in yt_ids:
|
304 |
+
try:
|
305 |
+
yt_video_txt= []
|
306 |
+
with st.spinner(f'👨💻 Scraping YT video : {id}'):
|
307 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(id)
|
308 |
+
transcript_en = None
|
309 |
+
last_language = ""
|
310 |
+
for transcript in transcript_list:
|
311 |
+
if transcript.language_code == 'en':
|
312 |
+
transcript_en = transcript
|
313 |
+
break
|
314 |
+
else:
|
315 |
+
last_language = transcript.language_code
|
316 |
+
if transcript_en is None:
|
317 |
+
transcript_en = transcript_list.find_transcript([last_language])
|
318 |
+
transcript_en = transcript_en.translate('en')
|
319 |
+
|
320 |
+
text = transcript_en.fetch()
|
321 |
+
yt_video_txt.append(text)
|
322 |
+
|
323 |
+
for i in range(len(yt_video_txt)):
|
324 |
+
for j in range(len(yt_video_txt[i])):
|
325 |
+
full_text.append(yt_video_txt[i][j]['text'])
|
326 |
+
|
327 |
+
|
328 |
+
except:
|
329 |
+
pass
|
330 |
+
|
331 |
+
with st.spinner('🧠 Building vectorstore with knowledge...'):
|
332 |
+
full_text = "\n".join(full_text)
|
333 |
+
st.session_state['god_text'] = [full_text]
|
334 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
|
335 |
+
texts = text_splitter.create_documents([full_text])
|
336 |
+
# Select embeddings
|
337 |
+
embeddings = st.session_state['hf']
|
338 |
+
# Create a vectorstore from documents
|
339 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
340 |
+
db = Chroma.from_documents(texts, embeddings, persist_directory="./chroma_db_" + random_str)
|
341 |
+
|
342 |
+
with st.spinner('🔨 Saving vectorstore...'):
|
343 |
+
# save vectorstore
|
344 |
+
db.persist()
|
345 |
+
#create .zip file of directory to download
|
346 |
+
shutil.make_archive("./chroma_db_" + random_str, 'zip', "./chroma_db_" + random_str)
|
347 |
+
# save in session state and download
|
348 |
+
st.session_state['db'] = "./chroma_db_" + random_str + ".zip"
|
349 |
+
|
350 |
+
with st.spinner('🔨 Creating QA chain...'):
|
351 |
+
# Create retriever interface
|
352 |
+
retriever = db.as_retriever()
|
353 |
+
# Create QA chain
|
354 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
355 |
+
st.session_state['god_mode'] = qa
|
356 |
+
st.session_state['god_mode_source'] = source
|
357 |
+
st.session_state['god_mode_info'] = "🧠 GOD MODE have builded a vectorstore about **" + topic + f"**. The knowledge is based on\n- {len(news)} news🗞\n- {len(yt_ids)} YT videos📺\n- {len(links)} websites🌐 \n"
|
358 |
+
|
359 |
+
st.experimental_rerun()
|
360 |
+
|
361 |
+
|
362 |
+
if st.session_state['plugin'] == "🧠 GOD MODE" and 'god_mode' in st.session_state:
|
363 |
+
with st.expander("**✅ GOD MODE is enabled 🚀**", expanded=True):
|
364 |
+
st.markdown(st.session_state['god_mode_info'])
|
365 |
+
if 'db' in st.session_state:
|
366 |
+
# leave ./ from name for download
|
367 |
+
file_name = st.session_state['db'][2:]
|
368 |
+
st.download_button(
|
369 |
+
label="📩 Download vectorstore",
|
370 |
+
data=open(file_name, 'rb').read(),
|
371 |
+
file_name=file_name,
|
372 |
+
mime='application/zip'
|
373 |
+
)
|
374 |
+
if st.button('🧠🛑 Disable GOD MODE'):
|
375 |
+
del st.session_state['god_mode']
|
376 |
+
del st.session_state['db']
|
377 |
+
del st.session_state['god_text']
|
378 |
+
del st.session_state['god_mode_info']
|
379 |
+
del st.session_state['god_mode_source']
|
380 |
+
del st.session_state['plugin']
|
381 |
+
st.experimental_rerun()
|
382 |
+
|
383 |
+
|
384 |
+
# DATA PLUGIN
|
385 |
+
if st.session_state['plugin'] == "📋 Talk with your DATA" and 'df' not in st.session_state:
|
386 |
+
with st.expander("📋 Talk with your DATA", expanded= True):
|
387 |
+
upload_csv = st.file_uploader("Upload your CSV", type=['csv'])
|
388 |
+
if upload_csv is not None:
|
389 |
+
df = pd.read_csv(upload_csv)
|
390 |
+
st.session_state['df'] = df
|
391 |
+
st.experimental_rerun()
|
392 |
+
if st.session_state['plugin'] == "📋 Talk with your DATA":
|
393 |
+
if st.button('🛑📋 Remove DATA from context'):
|
394 |
+
if 'df' in st.session_state:
|
395 |
+
del st.session_state['df']
|
396 |
+
del st.session_state['plugin']
|
397 |
+
st.experimental_rerun()
|
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
# DOCUMENTS PLUGIN
|
402 |
+
if st.session_state['plugin'] == "📝 Talk with your DOCUMENTS" and 'documents' not in st.session_state:
|
403 |
+
with st.expander("📝 Talk with your DOCUMENT", expanded=True):
|
404 |
+
upload_pdf = st.file_uploader("Upload your DOCUMENT", type=['txt', 'pdf', 'docx'], accept_multiple_files=True)
|
405 |
+
if upload_pdf is not None and st.button('📝✅ Load Documents'):
|
406 |
+
documents = []
|
407 |
+
with st.spinner('🔨 Reading documents...'):
|
408 |
+
for upload_pdf in upload_pdf:
|
409 |
+
print(upload_pdf.type)
|
410 |
+
if upload_pdf.type == 'text/plain':
|
411 |
+
documents += [upload_pdf.read().decode()]
|
412 |
+
elif upload_pdf.type == 'application/pdf':
|
413 |
+
with pdfplumber.open(upload_pdf) as pdf:
|
414 |
+
documents += [page.extract_text() for page in pdf.pages]
|
415 |
+
elif upload_pdf.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
416 |
+
text = docx2txt.process(upload_pdf)
|
417 |
+
documents += [text]
|
418 |
+
st.session_state['documents'] = documents
|
419 |
+
# Split documents into chunks
|
420 |
+
with st.spinner('🔨 Creating vectorstore...'):
|
421 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
|
422 |
+
texts = text_splitter.create_documents(documents)
|
423 |
+
# Select embeddings
|
424 |
+
embeddings = st.session_state['hf']
|
425 |
+
# Create a vectorstore from documents
|
426 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
427 |
+
db = Chroma.from_documents(texts, embeddings, persist_directory="./chroma_db_" + random_str)
|
428 |
+
|
429 |
+
with st.spinner('🔨 Saving vectorstore...'):
|
430 |
+
# save vectorstore
|
431 |
+
db.persist()
|
432 |
+
#create .zip file of directory to download
|
433 |
+
shutil.make_archive("./chroma_db_" + random_str, 'zip', "./chroma_db_" + random_str)
|
434 |
+
# save in session state and download
|
435 |
+
st.session_state['db'] = "./chroma_db_" + random_str + ".zip"
|
436 |
+
|
437 |
+
with st.spinner('🔨 Creating QA chain...'):
|
438 |
+
# Create retriever interface
|
439 |
+
retriever = db.as_retriever()
|
440 |
+
# Create QA chain
|
441 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
442 |
+
st.session_state['pdf'] = qa
|
443 |
+
|
444 |
+
st.experimental_rerun()
|
445 |
+
|
446 |
+
if st.session_state['plugin'] == "📝 Talk with your DOCUMENTS":
|
447 |
+
if 'db' in st.session_state:
|
448 |
+
# leave ./ from name for download
|
449 |
+
file_name = st.session_state['db'][2:]
|
450 |
+
st.download_button(
|
451 |
+
label="📩 Download vectorstore",
|
452 |
+
data=open(file_name, 'rb').read(),
|
453 |
+
file_name=file_name,
|
454 |
+
mime='application/zip'
|
455 |
+
)
|
456 |
+
if st.button('🛑📝 Remove PDF from context'):
|
457 |
+
if 'pdf' in st.session_state:
|
458 |
+
del st.session_state['db']
|
459 |
+
del st.session_state['pdf']
|
460 |
+
del st.session_state['documents']
|
461 |
+
del st.session_state['plugin']
|
462 |
+
|
463 |
+
st.experimental_rerun()
|
464 |
+
|
465 |
+
# AUDIO PLUGIN
|
466 |
+
if st.session_state['plugin'] == "🎧 Talk with your AUDIO" and 'audio' not in st.session_state:
|
467 |
+
with st.expander("🎙 Talk with your AUDIO", expanded=True):
|
468 |
+
f = st.file_uploader("Upload your AUDIO", type=['wav', 'mp3'])
|
469 |
+
if f is not None:
|
470 |
+
if f.type == 'audio/mpeg':
|
471 |
+
#convert mp3 to wav
|
472 |
+
with st.spinner('🔨 Converting mp3 to wav...'):
|
473 |
+
#save mp3
|
474 |
+
with open('audio.mp3', 'wb') as out:
|
475 |
+
out.write(f.read())
|
476 |
+
#convert to wav
|
477 |
+
sound = AudioSegment.from_mp3("audio.mp3")
|
478 |
+
sound.export("audio.wav", format="wav")
|
479 |
+
file_name = 'audio.wav'
|
480 |
+
else:
|
481 |
+
with open(f.name, 'wb') as out:
|
482 |
+
out.write(f.read())
|
483 |
+
|
484 |
+
bytes_data = f.read()
|
485 |
+
file_name = f.name
|
486 |
+
|
487 |
+
r = sr.Recognizer()
|
488 |
+
#Given audio file must be a filename string or a file-like object
|
489 |
+
|
490 |
+
|
491 |
+
with st.spinner('🔨 Reading audio...'):
|
492 |
+
with sr.AudioFile(file_name) as source:
|
493 |
+
# listen for the data (load audio to memory)
|
494 |
+
audio_data = r.record(source)
|
495 |
+
# recognize (convert from speech to text)
|
496 |
+
text = r.recognize_google(audio_data)
|
497 |
+
data = [text]
|
498 |
+
# data = query(bytes_data)
|
499 |
+
with st.spinner('🎙 Creating Vectorstore...'):
|
500 |
+
|
501 |
+
#split text into chunks
|
502 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
|
503 |
+
texts = text_splitter.create_documents(text)
|
504 |
+
|
505 |
+
embeddings = st.session_state['hf']
|
506 |
+
# Create a vectorstore from documents
|
507 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
508 |
+
db = Chroma.from_documents(texts, embeddings, persist_directory="./chroma_db_" + random_str)
|
509 |
+
# save vectorstore
|
510 |
+
|
511 |
+
with st.spinner('🎙 Saving Vectorstore...'):
|
512 |
+
db.persist()
|
513 |
+
#create .zip file of directory to download
|
514 |
+
shutil.make_archive("./chroma_db_" + random_str, 'zip', "./chroma_db_" + random_str)
|
515 |
+
# save in session state and download
|
516 |
+
st.session_state['db'] = "./chroma_db_" + random_str + ".zip"
|
517 |
+
|
518 |
+
with st.spinner('🎙 Creating QA chain...'):
|
519 |
+
# Create retriever interface
|
520 |
+
retriever = db.as_retriever()
|
521 |
+
# Create QA chain
|
522 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
523 |
+
st.session_state['audio'] = qa
|
524 |
+
st.session_state['audio_text'] = text
|
525 |
+
st.experimental_rerun()
|
526 |
+
|
527 |
+
if st.session_state['plugin'] == "🎧 Talk with your AUDIO":
|
528 |
+
if 'db' in st.session_state:
|
529 |
+
# leave ./ from name for download
|
530 |
+
file_name = st.session_state['db'][2:]
|
531 |
+
st.download_button(
|
532 |
+
label="📩 Download vectorstore",
|
533 |
+
data=open(file_name, 'rb').read(),
|
534 |
+
file_name=file_name,
|
535 |
+
mime='application/zip'
|
536 |
+
)
|
537 |
+
if st.button('🛑🎙 Remove AUDIO from context'):
|
538 |
+
if 'audio' in st.session_state:
|
539 |
+
del st.session_state['db']
|
540 |
+
del st.session_state['audio']
|
541 |
+
del st.session_state['audio_text']
|
542 |
+
del st.session_state['plugin']
|
543 |
+
st.experimental_rerun()
|
544 |
+
|
545 |
+
|
546 |
+
# YT PLUGIN
|
547 |
+
if st.session_state['plugin'] == "🎥 Talk with YT video" and 'yt' not in st.session_state:
|
548 |
+
with st.expander("🎥 Talk with YT video", expanded=True):
|
549 |
+
yt_url = st.text_input("1.📺 Enter a YouTube URL")
|
550 |
+
yt_url2 = st.text_input("2.📺 Enter a YouTube URL")
|
551 |
+
yt_url3 = st.text_input("3.📺 Enter a YouTube URL")
|
552 |
+
if yt_url is not None and st.button('🎥✅ Add YouTube video to context'):
|
553 |
+
if yt_url != "":
|
554 |
+
video = 1
|
555 |
+
yt_url = yt_url.split("=")[1]
|
556 |
+
if yt_url2 != "":
|
557 |
+
yt_url2 = yt_url2.split("=")[1]
|
558 |
+
video = 2
|
559 |
+
if yt_url3 != "":
|
560 |
+
yt_url3 = yt_url3.split("=")[1]
|
561 |
+
video = 3
|
562 |
+
|
563 |
+
text_yt = []
|
564 |
+
text_list = []
|
565 |
+
for i in range(video):
|
566 |
+
with st.spinner(f'🎥 Extracting TEXT from YouTube video {str(i)} ...'):
|
567 |
+
#get en subtitles
|
568 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(yt_url)
|
569 |
+
transcript_en = None
|
570 |
+
last_language = ""
|
571 |
+
for transcript in transcript_list:
|
572 |
+
if transcript.language_code == 'en':
|
573 |
+
transcript_en = transcript
|
574 |
+
break
|
575 |
+
else:
|
576 |
+
last_language = transcript.language_code
|
577 |
+
if transcript_en is None:
|
578 |
+
transcript_en = transcript_list.find_transcript([last_language])
|
579 |
+
transcript_en = transcript_en.translate('en')
|
580 |
+
|
581 |
+
text = transcript_en.fetch()
|
582 |
+
text_yt.append(text)
|
583 |
+
|
584 |
+
for i in range(len(text_yt)):
|
585 |
+
for j in range(len(text_yt[i])):
|
586 |
+
text_list.append(text_yt[i][j]['text'])
|
587 |
+
|
588 |
+
# creating a vectorstore
|
589 |
+
|
590 |
+
with st.spinner('🎥 Creating Vectorstore...'):
|
591 |
+
text_splitter = CharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
|
592 |
+
texts = text_splitter.create_documents(text_list)
|
593 |
+
# Select embeddings
|
594 |
+
embeddings = st.session_state['hf']
|
595 |
+
# Create a vectorstore from documents
|
596 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
597 |
+
db = Chroma.from_documents(texts, embeddings, persist_directory="./chroma_db_" + random_str)
|
598 |
+
|
599 |
+
with st.spinner('🎥 Saving Vectorstore...'):
|
600 |
+
# save vectorstore
|
601 |
+
db.persist()
|
602 |
+
#create .zip file of directory to download
|
603 |
+
shutil.make_archive("./chroma_db_" + random_str, 'zip', "./chroma_db_" + random_str)
|
604 |
+
# save in session state and download
|
605 |
+
st.session_state['db'] = "./chroma_db_" + random_str + ".zip"
|
606 |
+
|
607 |
+
with st.spinner('🎥 Creating QA chain...'):
|
608 |
+
# Create retriever interface
|
609 |
+
retriever = db.as_retriever(search_type="mmr", search_kwargs={'k': 100, 'lambda_mult': 0.25})
|
610 |
+
# Create QA chain
|
611 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
612 |
+
st.session_state['yt'] = qa
|
613 |
+
st.session_state['yt_text'] = text_list
|
614 |
+
st.experimental_rerun()
|
615 |
+
|
616 |
+
if st.session_state['plugin'] == "🎥 Talk with YT video":
|
617 |
+
if 'db' in st.session_state:
|
618 |
+
# leave ./ from name for download
|
619 |
+
file_name = st.session_state['db'][2:]
|
620 |
+
st.download_button(
|
621 |
+
label="📩 Download vectorstore",
|
622 |
+
data=open(file_name, 'rb').read(),
|
623 |
+
file_name=file_name,
|
624 |
+
mime='application/zip'
|
625 |
+
)
|
626 |
+
|
627 |
+
if st.button('🛑🎥 Remove YT video from context'):
|
628 |
+
if 'yt' in st.session_state:
|
629 |
+
del st.session_state['db']
|
630 |
+
del st.session_state['yt']
|
631 |
+
del st.session_state['yt_text']
|
632 |
+
del st.session_state['plugin']
|
633 |
+
st.experimental_rerun()
|
634 |
+
|
635 |
+
# WEBSITE PLUGIN
|
636 |
+
if st.session_state['plugin'] == "🔗 Talk with Website" and 'web_sites' not in st.session_state:
|
637 |
+
with st.expander("🔗 Talk with Website", expanded=True):
|
638 |
+
web_url = st.text_area("🔗 Enter a website URLs , one for each line")
|
639 |
+
if web_url is not None and st.button('🔗✅ Add website to context'):
|
640 |
+
if web_url != "":
|
641 |
+
text = []
|
642 |
+
#max 10 websites
|
643 |
+
with st.spinner('🔗 Extracting TEXT from Websites ...'):
|
644 |
+
for url in web_url.split("\n")[:10]:
|
645 |
+
page = requests.get(url)
|
646 |
+
soup = BeautifulSoup(page.content, 'html.parser')
|
647 |
+
text.append(soup.get_text())
|
648 |
+
# creating a vectorstore
|
649 |
+
|
650 |
+
with st.spinner('🔗 Creating Vectorstore...'):
|
651 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
|
652 |
+
texts = text_splitter.create_documents(text)
|
653 |
+
# Select embeddings
|
654 |
+
embeddings = st.session_state['hf']
|
655 |
+
# Create a vectorstore from documents
|
656 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
657 |
+
db = Chroma.from_documents(texts, embeddings, persist_directory="./chroma_db_" + random_str)
|
658 |
+
|
659 |
+
with st.spinner('🔗 Saving Vectorstore...'):
|
660 |
+
# save vectorstore
|
661 |
+
db.persist()
|
662 |
+
#create .zip file of directory to download
|
663 |
+
shutil.make_archive("./chroma_db_" + random_str, 'zip', "./chroma_db_" + random_str)
|
664 |
+
# save in session state and download
|
665 |
+
st.session_state['db'] = "./chroma_db_" + random_str + ".zip"
|
666 |
+
|
667 |
+
with st.spinner('🔗 Creating QA chain...'):
|
668 |
+
# Create retriever interface
|
669 |
+
retriever = db.as_retriever()
|
670 |
+
# Create QA chain
|
671 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
672 |
+
st.session_state['web_sites'] = qa
|
673 |
+
st.session_state['web_text'] = text
|
674 |
+
st.experimental_rerun()
|
675 |
+
|
676 |
+
if st.session_state['plugin'] == "🔗 Talk with Website":
|
677 |
+
if 'db' in st.session_state:
|
678 |
+
# leave ./ from name for download
|
679 |
+
file_name = st.session_state['db'][2:]
|
680 |
+
st.download_button(
|
681 |
+
label="📩 Download vectorstore",
|
682 |
+
data=open(file_name, 'rb').read(),
|
683 |
+
file_name=file_name,
|
684 |
+
mime='application/zip'
|
685 |
+
)
|
686 |
+
|
687 |
+
if st.button('🛑🔗 Remove Website from context'):
|
688 |
+
if 'web_sites' in st.session_state:
|
689 |
+
del st.session_state['db']
|
690 |
+
del st.session_state['web_sites']
|
691 |
+
del st.session_state['web_text']
|
692 |
+
del st.session_state['plugin']
|
693 |
+
st.experimental_rerun()
|
694 |
+
|
695 |
+
|
696 |
+
# UPLOAD PREVIUS VECTORSTORE
|
697 |
+
if st.session_state['plugin'] == "💾 Upload saved VectorStore" and 'old_db' not in st.session_state:
|
698 |
+
with st.expander("💾 Upload saved VectorStore", expanded=True):
|
699 |
+
db_file = st.file_uploader("Upload a saved VectorStore", type=["zip"])
|
700 |
+
if db_file is not None and st.button('✅💾 Add saved VectorStore to context'):
|
701 |
+
if db_file != "":
|
702 |
+
with st.spinner('💾 Extracting VectorStore...'):
|
703 |
+
# unzip file in a new directory
|
704 |
+
with ZipFile(db_file, 'r') as zipObj:
|
705 |
+
# Extract all the contents of zip file in different directory
|
706 |
+
random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
|
707 |
+
zipObj.extractall("chroma_db_" + random_str)
|
708 |
+
# save in session state the path of the directory
|
709 |
+
st.session_state['old_db'] = "chroma_db_" + random_str
|
710 |
+
hf = st.session_state['hf']
|
711 |
+
# Create retriever interface
|
712 |
+
db = Chroma("chroma_db_" + random_str, embedding_function=hf)
|
713 |
+
|
714 |
+
with st.spinner('💾 Creating QA chain...'):
|
715 |
+
retriever = db.as_retriever()
|
716 |
+
# Create QA chain
|
717 |
+
qa = RetrievalQA.from_chain_type(llm=st.session_state['LLM'], chain_type='stuff', retriever=retriever, return_source_documents=True)
|
718 |
+
st.session_state['old_db'] = qa
|
719 |
+
st.experimental_rerun()
|
720 |
+
|
721 |
+
if st.session_state['plugin'] == "💾 Upload saved VectorStore":
|
722 |
+
if st.button('🛑💾 Remove VectorStore from context'):
|
723 |
+
if 'old_db' in st.session_state:
|
724 |
+
del st.session_state['old_db']
|
725 |
+
del st.session_state['plugin']
|
726 |
+
st.experimental_rerun()
|
727 |
+
|
728 |
+
|
729 |
+
# END OF PLUGIN
|
730 |
+
add_vertical_space(4)
|
731 |
+
if 'hf_email' in st.session_state:
|
732 |
+
if st.button('🗑 Logout'):
|
733 |
+
keys = list(st.session_state.keys())
|
734 |
+
for key in keys:
|
735 |
+
del st.session_state[key]
|
736 |
+
st.experimental_rerun()
|
737 |
+
|
738 |
+
export_chat()
|
739 |
+
# add_vertical_space(5)
|
740 |
+
# html_chat = '<center><h6>🤗 Support the project with a donation for the development of new features 🤗</h6>'
|
741 |
+
# html_chat += '<br><a href="https://rebrand.ly/SupportAUTOGPTfree"><img src="https://www.paypalobjects.com/en_US/i/btn/btn_donateCC_LG.gif" alt="PayPal donate button" /></a><center><br>'
|
742 |
+
# st.markdown(html_chat, unsafe_allow_html=True)
|
743 |
+
# st.write('Made with ❤️ by [Alessandro CIciarelli](https://intelligenzaartificialeitalia.net)')
|
744 |
+
|
745 |
+
##### End of sidebar
|
746 |
+
|
747 |
+
|
748 |
+
# User input
|
749 |
+
# Layout of input/response containers
|
750 |
+
input_container = st.container()
|
751 |
+
response_container = st.container()
|
752 |
+
data_view_container = st.container()
|
753 |
+
loading_container = st.container()
|
754 |
+
|
755 |
+
|
756 |
+
|
757 |
+
## Applying the user input box
|
758 |
+
with input_container:
|
759 |
+
input_text = st.chat_input("🧑💻 Write here 👇", key="input")
|
760 |
+
|
761 |
+
with data_view_container:
|
762 |
+
if 'df' in st.session_state:
|
763 |
+
with st.expander("🤖 View your **DATA**"):
|
764 |
+
st.data_editor(st.session_state['df'], use_container_width=True)
|
765 |
+
if 'pdf' in st.session_state:
|
766 |
+
with st.expander("🤖 View your **DOCUMENTs**"):
|
767 |
+
st.write(st.session_state['documents'])
|
768 |
+
if 'audio' in st.session_state:
|
769 |
+
with st.expander("🤖 View your **AUDIO**"):
|
770 |
+
st.write(st.session_state['audio_text'])
|
771 |
+
if 'yt' in st.session_state:
|
772 |
+
with st.expander("🤖 View your **YT video**"):
|
773 |
+
st.write(st.session_state['yt_text'])
|
774 |
+
if 'web_text' in st.session_state:
|
775 |
+
with st.expander("🤖 View the **Website content**"):
|
776 |
+
st.write(st.session_state['web_text'])
|
777 |
+
if 'old_db' in st.session_state:
|
778 |
+
with st.expander("🗂 View your **saved VectorStore**"):
|
779 |
+
st.success("📚 VectorStore loaded")
|
780 |
+
if 'god_mode_source' in st.session_state:
|
781 |
+
with st.expander("🌍 View source"):
|
782 |
+
for s in st.session_state['god_mode_source']:
|
783 |
+
st.markdown("- " + s)
|
784 |
+
|
785 |
+
# Response output
|
786 |
+
## Function for taking user prompt as input followed by producing AI generated responses
|
787 |
+
def generate_response(prompt):
|
788 |
+
final_prompt = ""
|
789 |
+
make_better = True
|
790 |
+
source = ""
|
791 |
+
|
792 |
+
with loading_container:
|
793 |
+
|
794 |
+
# FOR DEVELOPMENT PLUGIN
|
795 |
+
# if st.session_state['plugin'] == "🔌 PLUGIN NAME" and 'PLUGIN DB' in st.session_state:
|
796 |
+
# with st.spinner('🚀 Using PLUGIN NAME...'):
|
797 |
+
# solution = st.session_state['PLUGIN DB']({"query": prompt})
|
798 |
+
# final_prompt = YourCustomPrompt(prompt, context)
|
799 |
+
|
800 |
+
|
801 |
+
if st.session_state['plugin'] == "📋 Talk with your DATA" and 'df' in st.session_state:
|
802 |
+
#get only last message
|
803 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
804 |
+
if prompt.find('python') != -1 or prompt.find('Code') != -1 or prompt.find('code') != -1 or prompt.find('Python') != -1:
|
805 |
+
with st.spinner('🚀 Using tool for python code...'):
|
806 |
+
solution = "\n```python\n"
|
807 |
+
solution += st.session_state['df'].sketch.howto(prompt, call_display=False)
|
808 |
+
solution += "\n```\n\n"
|
809 |
+
final_prompt = prompt4Code(prompt, context, solution)
|
810 |
+
else:
|
811 |
+
with st.spinner('🚀 Using tool to get information...'):
|
812 |
+
solution = st.session_state['df'].sketch.ask(prompt, call_display=False)
|
813 |
+
final_prompt = prompt4Data(prompt, context, solution)
|
814 |
+
|
815 |
+
|
816 |
+
elif st.session_state['plugin'] == "📝 Talk with your DOCUMENTS" and 'pdf' in st.session_state:
|
817 |
+
#get only last message
|
818 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
819 |
+
with st.spinner('🚀 Using tool to get information...'):
|
820 |
+
result = st.session_state['pdf']({"query": prompt})
|
821 |
+
solution = result["result"]
|
822 |
+
if len(solution.split()) > 110:
|
823 |
+
make_better = False
|
824 |
+
final_prompt = solution
|
825 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
826 |
+
final_prompt += "\n\n✅Source:\n"
|
827 |
+
for d in result["source_documents"]:
|
828 |
+
final_prompt += "- " + str(d) + "\n"
|
829 |
+
else:
|
830 |
+
final_prompt = prompt4Context(prompt, context, solution)
|
831 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
832 |
+
source += "\n\n✅Source:\n"
|
833 |
+
for d in result["source_documents"]:
|
834 |
+
source += "- " + str(d) + "\n"
|
835 |
+
|
836 |
+
|
837 |
+
elif st.session_state['plugin'] == "🧠 GOD MODE" and 'god_mode' in st.session_state:
|
838 |
+
#get only last message
|
839 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
840 |
+
with st.spinner('🚀 Using tool to get information...'):
|
841 |
+
result = st.session_state['god_mode']({"query": prompt})
|
842 |
+
solution = result["result"]
|
843 |
+
if len(solution.split()) > 110:
|
844 |
+
make_better = False
|
845 |
+
final_prompt = solution
|
846 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
847 |
+
final_prompt += "\n\n✅Source:\n"
|
848 |
+
for d in result["source_documents"]:
|
849 |
+
final_prompt += "- " + str(d) + "\n"
|
850 |
+
else:
|
851 |
+
final_prompt = prompt4Context(prompt, context, solution)
|
852 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
853 |
+
source += "\n\n✅Source:\n"
|
854 |
+
for d in result["source_documents"]:
|
855 |
+
source += "- " + str(d) + "\n"
|
856 |
+
|
857 |
+
|
858 |
+
elif st.session_state['plugin'] == "🔗 Talk with Website" and 'web_sites' in st.session_state:
|
859 |
+
#get only last message
|
860 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
861 |
+
with st.spinner('🚀 Using tool to get information...'):
|
862 |
+
result = st.session_state['web_sites']({"query": prompt})
|
863 |
+
solution = result["result"]
|
864 |
+
if len(solution.split()) > 110:
|
865 |
+
make_better = False
|
866 |
+
final_prompt = solution
|
867 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
868 |
+
final_prompt += "\n\n✅Source:\n"
|
869 |
+
for d in result["source_documents"]:
|
870 |
+
final_prompt += "- " + str(d) + "\n"
|
871 |
+
else:
|
872 |
+
final_prompt = prompt4Context(prompt, context, solution)
|
873 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
874 |
+
source += "\n\n✅Source:\n"
|
875 |
+
for d in result["source_documents"]:
|
876 |
+
source += "- " + str(d) + "\n"
|
877 |
+
|
878 |
+
|
879 |
+
|
880 |
+
elif st.session_state['plugin'] == "💾 Upload saved VectorStore" and 'old_db' in st.session_state:
|
881 |
+
#get only last message
|
882 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
883 |
+
with st.spinner('🚀 Using tool to get information...'):
|
884 |
+
result = st.session_state['old_db']({"query": prompt})
|
885 |
+
solution = result["result"]
|
886 |
+
if len(solution.split()) > 110:
|
887 |
+
make_better = False
|
888 |
+
final_prompt = solution
|
889 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
890 |
+
final_prompt += "\n\n✅Source:\n"
|
891 |
+
for d in result["source_documents"]:
|
892 |
+
final_prompt += "- " + str(d) + "\n"
|
893 |
+
else:
|
894 |
+
final_prompt = prompt4Context(prompt, context, solution)
|
895 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
896 |
+
source += "\n\n✅Source:\n"
|
897 |
+
for d in result["source_documents"]:
|
898 |
+
source += "- " + str(d) + "\n"
|
899 |
+
|
900 |
+
|
901 |
+
elif st.session_state['plugin'] == "🎧 Talk with your AUDIO" and 'audio' in st.session_state:
|
902 |
+
#get only last message
|
903 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
904 |
+
with st.spinner('🚀 Using tool to get information...'):
|
905 |
+
result = st.session_state['audio']({"query": prompt})
|
906 |
+
solution = result["result"]
|
907 |
+
if len(solution.split()) > 110:
|
908 |
+
make_better = False
|
909 |
+
final_prompt = solution
|
910 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
911 |
+
final_prompt += "\n\n✅Source:\n"
|
912 |
+
for d in result["source_documents"]:
|
913 |
+
final_prompt += "- " + str(d) + "\n"
|
914 |
+
else:
|
915 |
+
final_prompt = prompt4Audio(prompt, context, solution)
|
916 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
917 |
+
source += "\n\n✅Source:\n"
|
918 |
+
for d in result["source_documents"]:
|
919 |
+
source += "- " + str(d) + "\n"
|
920 |
+
|
921 |
+
|
922 |
+
elif st.session_state['plugin'] == "🎥 Talk with YT video" and 'yt' in st.session_state:
|
923 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
924 |
+
with st.spinner('🚀 Using tool to get information...'):
|
925 |
+
result = st.session_state['yt']({"query": prompt})
|
926 |
+
solution = result["result"]
|
927 |
+
if len(solution.split()) > 110:
|
928 |
+
make_better = False
|
929 |
+
final_prompt = solution
|
930 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
931 |
+
final_prompt += "\n\n✅Source:\n"
|
932 |
+
for d in result["source_documents"]:
|
933 |
+
final_prompt += "- " + str(d) + "\n"
|
934 |
+
else:
|
935 |
+
final_prompt = prompt4YT(prompt, context, solution)
|
936 |
+
if 'source_documents' in result and len(result["source_documents"]) > 0:
|
937 |
+
source += "\n\n✅Source:\n"
|
938 |
+
for d in result["source_documents"]:
|
939 |
+
source += "- " + str(d) + "\n"
|
940 |
+
|
941 |
+
|
942 |
+
else:
|
943 |
+
#get last message if exists
|
944 |
+
if len(st.session_state['past']) == 1:
|
945 |
+
context = f"User: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
946 |
+
else:
|
947 |
+
context = f"User: {st.session_state['past'][-2]}\nBot: {st.session_state['generated'][-2]}\nUser: {st.session_state['past'][-1]}\nBot: {st.session_state['generated'][-1]}\n"
|
948 |
+
|
949 |
+
if 'web_search' in st.session_state:
|
950 |
+
if st.session_state['web_search'] == "True":
|
951 |
+
with st.spinner('🚀 Using internet to get information...'):
|
952 |
+
internet_result = ""
|
953 |
+
internet_answer = ""
|
954 |
+
with DDGS() as ddgs:
|
955 |
+
ddgs_gen = ddgs.text(prompt, region=st.session_state['region'], safesearch=st.session_state['safesearch'], timelimit=st.session_state['timelimit'])
|
956 |
+
for r in islice(ddgs_gen, st.session_state['max_results']):
|
957 |
+
internet_result += str(r) + "\n\n"
|
958 |
+
fast_answer = ddgs.answers(prompt)
|
959 |
+
for r in islice(fast_answer, 2):
|
960 |
+
internet_answer += str(r) + "\n\n"
|
961 |
+
|
962 |
+
final_prompt = prompt4conversationInternet(prompt, context, internet_result, internet_answer)
|
963 |
+
else:
|
964 |
+
final_prompt = prompt4conversation(prompt, context)
|
965 |
+
else:
|
966 |
+
final_prompt = prompt4conversation(prompt, context)
|
967 |
+
|
968 |
+
if make_better:
|
969 |
+
with st.spinner('🚀 Generating response...'):
|
970 |
+
print(final_prompt)
|
971 |
+
response = st.session_state['chatbot'].chat(final_prompt, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, top_k=top_k, max_new_tokens=max_new_tokens)
|
972 |
+
response += source
|
973 |
+
else:
|
974 |
+
print(final_prompt)
|
975 |
+
response = final_prompt
|
976 |
+
|
977 |
+
return response
|
978 |
+
|
979 |
+
## Conditional display of AI generated responses as a function of user provided prompts
|
980 |
+
with response_container:
|
981 |
+
if input_text and 'hf_email' in st.session_state and 'hf_pass' in st.session_state:
|
982 |
+
response = generate_response(input_text)
|
983 |
+
st.session_state.past.append(input_text)
|
984 |
+
st.session_state.generated.append(response)
|
985 |
+
|
986 |
+
|
987 |
+
#print message in normal order, frist user then bot
|
988 |
+
if 'generated' in st.session_state:
|
989 |
+
if st.session_state['generated']:
|
990 |
+
for i in range(len(st.session_state['generated'])):
|
991 |
+
with st.chat_message(name="user"):
|
992 |
+
st.markdown(st.session_state['past'][i])
|
993 |
+
|
994 |
+
with st.chat_message(name="assistant"):
|
995 |
+
if len(st.session_state['generated'][i].split("✅Source:")) > 1:
|
996 |
+
source = st.session_state['generated'][i].split("✅Source:")[1]
|
997 |
+
mess = st.session_state['generated'][i].split("✅Source:")[0]
|
998 |
+
|
999 |
+
st.markdown(mess)
|
1000 |
+
with st.expander("📚 Source of message number " + str(i+1)):
|
1001 |
+
st.markdown(source)
|
1002 |
+
|
1003 |
+
else:
|
1004 |
+
st.markdown(st.session_state['generated'][i])
|
1005 |
+
|
1006 |
+
st.markdown('', unsafe_allow_html=True)
|
1007 |
+
|
1008 |
+
|
1009 |
+
else:
|
1010 |
+
st.info("👋 Hey , we are very happy to see you here 🤗")
|
1011 |
+
st.info("👉 Please Login to continue, click on top left corner to login 🚀")
|
1012 |
+
st.error("👉 If you are not registered on Hugging Face, please register first and then login 🤗")
|