Florin Bobiș commited on
Commit
9f814ca
·
1 Parent(s): 9b7056a
Files changed (7) hide show
  1. .gitignore +2 -0
  2. app.py +51 -0
  3. index.py +0 -17
  4. package.json +0 -5
  5. public/favicon.ico +0 -0
  6. requirements.txt +3 -3
  7. test.http +0 -5
.gitignore CHANGED
@@ -159,3 +159,5 @@ cython_debug/
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
  .vercel
 
 
 
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
  .vercel
162
+
163
+ cache/
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import MT5ForConditionalGeneration, T5Tokenizer
3
+ import time
4
+
5
+ @st.cache_resource
6
+ def load_model():
7
+ model = MT5ForConditionalGeneration.from_pretrained('iliemihai/mt5-base-romanian-diacritics', cache_dir='cache/')
8
+ return model
9
+
10
+ @st.cache_resource
11
+ def load_tokenizer():
12
+ tokenizer = T5Tokenizer.from_pretrained('iliemihai/mt5-base-romanian-diacritics', legacy=False, cache_dir='cache/')
13
+ return tokenizer
14
+
15
+ def initialize_app():
16
+ st.set_page_config(
17
+ page_title="Dia-critic",
18
+ page_icon="public/favicon.ico",
19
+ menu_items={
20
+ "About": "### Contact\n ✉️[email protected]",
21
+ },
22
+ )
23
+ st.title("🖋️Dia-critic")
24
+ st.caption("Made with :heart: by NEBO Technologies")
25
+
26
+ def generate_text(text):
27
+ model = load_model()
28
+ tokenizer = load_tokenizer()
29
+ inputs = tokenizer(text, max_length=256, truncation=True, return_tensors="pt")
30
+ outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
31
+ output = tokenizer.decode(outputs[0], skip_special_tokens=True)
32
+ return output
33
+
34
+ def main():
35
+ initialize_app()
36
+
37
+ input_text = st.text_area("Introduceți textul mai jos")
38
+ st.write(f'{len(input_text)} caractere.')
39
+ if st.button("Corectează"):
40
+ if input_text != "":
41
+ res = ''
42
+ with st.spinner('Sarcină în desfășurare...'):
43
+ # start task
44
+ res = generate_text(input_text)
45
+ with st.container(border=True):
46
+ st.markdown(res)
47
+ else:
48
+ st.warning("Câmpul este gol!")
49
+
50
+ if __name__ == "__main__":
51
+ main()
index.py DELETED
@@ -1,17 +0,0 @@
1
- from flask import Flask, request, jsonify
2
- from transformers import MT5ForConditionalGeneration, T5Tokenizer
3
-
4
- app = Flask(__name__)
5
-
6
- model = MT5ForConditionalGeneration.from_pretrained('iliemihai/mt5-base-romanian-diacritics', cache_dir='cache/')
7
- tokenizer = T5Tokenizer.from_pretrained('iliemihai/mt5-base-romanian-diacritics', legacy=False, cache_dir='cache/')
8
-
9
- @app.route('/generate', methods=['POST'])
10
- def generate_text():
11
- input_text = request.get_json()['input_text']
12
- if input_text is None:
13
- return jsonify({'error': 'No input text provided'})
14
- inputs = tokenizer(input_text, max_length=256, truncation=True, return_tensors="pt")
15
- outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
16
- output = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
- return jsonify({'output': output})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
package.json DELETED
@@ -1,5 +0,0 @@
1
- {
2
- "engines": {
3
- "node": "18.x"
4
- }
5
- }
 
 
 
 
 
 
public/favicon.ico ADDED
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- Flask==3.0.0
2
- torch
3
  torchvision
4
- transformers
 
1
+ streamlit
2
+ transformers==4.27.4
3
  torchvision
4
+ sentencepiece
test.http DELETED
@@ -1,5 +0,0 @@
1
- POST http://94.101.98.71:5000/generate
2
-
3
- {
4
- "input_text": "cat de multe as vrea sa-ti spun tie"
5
- }