openfree committed on
Commit
f687683
·
verified ·
1 Parent(s): 95488d7

Update src/main.py

Browse files
Files changed (1) hide show
  1. src/main.py +16 -11
src/main.py CHANGED
@@ -2,33 +2,38 @@ import display_gloss as dg
2
  import synonyms_preprocess as sp
3
  from NLP_Spacy_base_translator import NlpSpacyBaseTranslator
4
  from flask import Flask, render_template, Response, request
5
- from transformers import pipeline
6
  import torch
7
  import os
8
 
9
  app = Flask(__name__)
10
  app.config['TITLE'] = 'ASL Translator'
11
 
12
- # Set cache directory
13
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
14
- os.makedirs('/tmp/transformers_cache', exist_ok=True)
 
 
 
15
 
16
- # Force CPU usage
17
  device = torch.device('cpu')
18
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
19
 
20
- # Initialize translator with local cache
21
- translator = pipeline("translation",
22
- model="Helsinki-NLP/opus-mt-ko-en",
23
- device=device,
24
- model_kwargs={"cache_dir": "/tmp/transformers_cache"})
25
 
26
  nlp, dict_docs_spacy = sp.load_spacy_values()
27
  dataset, list_2000_tokens = dg.load_data()
28
 
29
  def translate_korean_to_english(text):
30
  if any('\u3131' <= char <= '\u318F' or '\uAC00' <= char <= '\uD7A3' for char in text):
31
- translation = translator(text)[0]['translation_text']
 
 
32
  return translation
33
  return text
34
 
 
# NOTE: `dg` (display_gloss) is imported on line 1 of src/main.py, which sits
# just above this hunk.
import synonyms_preprocess as sp
from NLP_Spacy_base_translator import NlpSpacyBaseTranslator
from flask import Flask, render_template, Response, request
# AutoModelForSeq2SeqLM is the auto class for encoder-decoder (translation)
# checkpoints; `AutoModelForSeq2SeqGeneration` does not exist in transformers
# and made this import fail at startup.
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os

app = Flask(__name__)
app.config['TITLE'] = 'ASL Translator'

# Cache directory for downloaded Hugging Face artifacts.
cache_dir = "/tmp/huggingface"
# exist_ok=True already tolerates an existing directory, so no prior
# os.path.exists() check is needed.
os.makedirs(cache_dir, exist_ok=True)
# NOTE(review): these variables are set after `transformers` has been imported,
# so they may not affect the library's default cache location -- the explicit
# cache_dir= arguments below are what this module relies on. Confirm before
# depending on the environment variables alone.
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir

# Run on CPU only: use a CPU device and hide all CUDA devices.
device = torch.device('cpu')
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Load the Korean -> English translation model and its tokenizer.
model_name = "Helsinki-NLP/opus-mt-ko-en"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
model = model.to(device)

nlp, dict_docs_spacy = sp.load_spacy_values()
dataset, list_2000_tokens = dg.load_data()

def translate_korean_to_english(text):
    """Translate *text* to English if it contains Hangul characters.

    Text is treated as Korean when at least one character falls in
    U+3131-U+318F (Hangul compatibility jamo) or U+AC00-U+D7A3 (Hangul
    syllables). Any other input, including an empty string or ``None``,
    is returned unchanged.

    Args:
        text: The user-supplied string to inspect and possibly translate.

    Returns:
        The decoded translation produced by the module-level ``model`` /
        ``tokenizer`` pair when Hangul is present; otherwise ``text`` itself.
    """
    # Nothing to scan or translate; also avoids iterating over None.
    if not text:
        return text

    has_hangul = any(
        '\u3131' <= char <= '\u318F' or '\uAC00' <= char <= '\uD7A3'
        for char in text
    )
    if not has_hangul:
        return text

    # truncation=True keeps over-long inputs within the tokenizer's maximum
    # length instead of feeding the model a sequence it cannot handle; text
    # beyond that limit is dropped from the translation.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Keep the input tensors on the same device as the model weights.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs)
    # Single input string, so the first (only) sequence is the translation.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
39