Commit a3452689 authored by Nils König

[refactor, #58] replace model in question generator

parent eb67e728
Pipeline #75720 passed with stages in 3 minutes and 7 seconds
@@ -72,8 +72,8 @@ async def exception_handler(request: Request, exception: Exception):
 @app.post("/questionGenerator")
-async def api_questionGenerator(item: Item):
-    qg = QuestionGenerator()
+async def api_question_generator(item: Item):
+    qg = QuestionGenerator(_nlp=nlp)
     return qg.generate(
         item.text,
         num_questions=item.num_questions,
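The handler now injects a shared spaCy pipeline instead of letting every QuestionGenerator instance load its own model. A minimal sketch of how the app module could provide that module-level `nlp` object (the load call is an assumption; the commit only shows that the generator's own `en_core_web_sm` import was removed):

    import spacy

    # Assumption: loaded once at app startup and reused across requests,
    # replacing the per-instance en_core_web_sm.load() removed further down.
    nlp = spacy.load("en_core_web_sm")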
@@ -89,11 +89,11 @@ async def test(item2: Item2):
         force_overwriting=True
     )
     facade.run()
-    returnJson = open(os.path.join(
+    return_json = open(os.path.join(
         "./app/output/train/output.json"),
         "r"
     ).read()
-    return json.loads(returnJson)
+    return json.loads(return_json)

 @app.get("/")
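Aside: `open(...).read()` here never closes the file handle. A sketch of an equivalent read using a context manager (not part of this commit):

    import json

    # Equivalent read that closes the handle deterministically.
    with open("./app/output/train/output.json", "r") as f:
        return_json = f.read()
    data = json.loads(return_json)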
@@ -7,7 +7,6 @@ import torch
 import re
 import random
 import json
-import en_core_web_sm
 from transformers import (
     AutoTokenizer,
     AutoModelForSeq2SeqLM,
@@ -16,9 +15,9 @@ from transformers import (
 class QuestionGenerator:
-    def __init__(self, model_dir=None):
-        QG_PRETRAINED = "iarfmoose/t5-base-question-generator"
+    def __init__(self, _nlp, model_dir=None):
+        qg_pretrained = "iarfmoose/t5-base-question-generator"
         self.ANSWER_TOKEN = "<answer>"
         self.CONTEXT_TOKEN = "<context>"
         self.SEQ_LENGTH = 512
@@ -27,13 +26,15 @@ class QuestionGenerator:
             torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.qg_tokenizer = \
-            AutoTokenizer.from_pretrained(QG_PRETRAINED, use_fast=False)
+            AutoTokenizer.from_pretrained(qg_pretrained, use_fast=False)
         self.qg_model = \
-            AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
+            AutoModelForSeq2SeqLM.from_pretrained(qg_pretrained)
         self.qg_model.to(self.device)
         self.qa_evaluator = QAEvaluator(model_dir)
+        self._nlp = _nlp

     def generate(
         self,
         article,
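For context, a usage sketch of the refactored constructor (the call shape is an assumption pieced together from the parameters visible in this diff; `generate`'s full signature is truncated here):

    import spacy

    nlp = spacy.load("en_core_web_sm")  # assumed pipeline, as above
    qg = QuestionGenerator(_nlp=nlp)
    qa_list = qg.generate(
        "Alan Turing was a pioneer of theoretical computer science.",
        num_questions=3,
    )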
@@ -79,12 +80,12 @@ class QuestionGenerator:
     def generate_qg_inputs(self, text, answer_style):
-        VALID_ANSWER_STYLES = ["all", "sentences", "multiple_choice"]
+        valid_answer_styles = ["all", "sentences", "multiple_choice"]

-        if answer_style not in VALID_ANSWER_STYLES:
+        if answer_style not in valid_answer_styles:
             raise ValueError(
                 "Invalid answer style {}. Please choose from {}".format(
-                    answer_style, VALID_ANSWER_STYLES
+                    answer_style, valid_answer_styles
                 )
             )
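For illustration, the validation above rejects anything outside the three listed styles (a sketch, assuming the `nlp` object and constructor call from the earlier hunks):

    qg = QuestionGenerator(_nlp=nlp)
    qg.generate_qg_inputs("Some text.", "essay")
    # ValueError: Invalid answer style essay.
    # Please choose from ['all', 'sentences', 'multiple_choice']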
@@ -104,7 +105,7 @@ class QuestionGenerator:
         if answer_style == "multiple_choice" or answer_style == "all":
             sentences = self._split_text(text)
             prepped_inputs, prepped_answers = \
-                self._prepare_qg_inputs_MC(sentences)
+                self._prepare_qg_inputs_mc(sentences)
             inputs.extend(prepped_inputs)
             answers.extend(prepped_answers)
@@ -119,14 +120,15 @@ class QuestionGenerator:
         return generated_questions

-    def _split_text(self, text):
-        MAX_SENTENCE_LEN = 128
+    @staticmethod
+    def _split_text(text):
+        max_sentence_len = 128
         sentences = re.findall(r".*?[.!\?]", text)
         cut_sentences = []
         for sentence in sentences:
-            if len(sentence) > MAX_SENTENCE_LEN:
+            if len(sentence) > max_sentence_len:
                 cut_sentences.extend(re.split("[,;:)]", sentence))
         # temporary solution to remove useless post-quote sentence fragments
         cut_sentences = [s for s in sentences if len(s.split(" ")) > 5]
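For illustration, the sentence regex matches the shortest run of characters up to each terminal punctuation mark (a standalone sketch, not part of the commit):

    import re

    text = "Short one. This sentence is fine! Is this one too?"
    print(re.findall(r".*?[.!\?]", text))
    # ['Short one.', ' This sentence is fine!', ' Is this one too?']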
@@ -135,7 +137,7 @@ class QuestionGenerator:
         return list(set([s.strip(" ") for s in sentences]))

     def _split_into_segments(self, text):
-        MAX_TOKENS = 490
+        max_tokens = 490
         paragraphs = text.split("\n")
         tokenized_paragraphs = [
@@ -145,7 +147,7 @@ class QuestionGenerator:
         segments = []
         while len(tokenized_paragraphs) > 0:
             segment = []
-            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
+            while len(segment) < max_tokens and len(tokenized_paragraphs) > 0:
                 paragraph = tokenized_paragraphs.pop(0)
                 segment.extend(paragraph)
             segments.append(segment)
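A worked example of the packing loop: the inner condition is checked before a paragraph is appended, so a segment can overshoot `max_tokens` by up to one paragraph (sketch with dummy tokens):

    max_tokens = 490
    tokenized_paragraphs = [["tok"] * 300, ["tok"] * 250, ["tok"] * 100]
    segments = []
    while len(tokenized_paragraphs) > 0:
        segment = []
        while len(segment) < max_tokens and len(tokenized_paragraphs) > 0:
            segment.extend(tokenized_paragraphs.pop(0))
        segments.append(segment)
    print([len(s) for s in segments])  # [550, 100] -- first segment overshoots the cap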
@@ -164,10 +166,9 @@ class QuestionGenerator:
         return inputs, answers

-    def _prepare_qg_inputs_MC(self, sentences):
-        spacy_nlp = en_core_web_sm.load()
-        docs = list(spacy_nlp.pipe(sentences, disable=["parser"]))
+    def _prepare_qg_inputs_mc(self, sentences):
+        docs = list(self._nlp.pipe(sentences, disable=["parser"]))
         inputs_from_text = []
         answers_from_text = []
@@ -181,13 +182,14 @@ class QuestionGenerator:
                 self.CONTEXT_TOKEN,
                 sentences[i]
             )
-            answers = self._get_MC_answers(entity, docs)
+            answers = self._get_mc_answers(entity, docs)
             inputs_from_text.append(qg_input)
             answers_from_text.append(answers)

         return inputs_from_text, answers_from_text

-    def _get_MC_answers(self, correct_answer, docs):
+    @staticmethod
+    def _get_mc_answers(correct_answer, docs):
         entities = []
         for doc in docs:
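The multiple-choice path now reuses the injected pipeline; with the parser disabled, spaCy still runs its NER component, so each doc exposes `doc.ents`. A standalone sketch (the example text and entity labels are illustrative):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    docs = list(nlp.pipe(["Alan Turing was born in London."], disable=["parser"]))
    print([(ent.text, ent.label_) for ent in docs[0].ents])
    # e.g. [('Alan Turing', 'PERSON'), ('London', 'GPE')]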
@@ -286,25 +288,24 @@ class QuestionGenerator:
             qa_list.append(qa)
         return qa_list

-    def _make_dict(self, question, answer):
-        qa = {}
-        qa["question"] = question
-        qa["answer"] = answer
+    @staticmethod
+    def _make_dict(question, answer):
+        qa = {"question": question, "answer": answer}
         return qa


 class QAEvaluator:
     def __init__(self, model_dir=None):
-        QAE_PRETRAINED = "iarfmoose/bert-base-cased-qa-evaluator"
+        qae_pretrained = "iarfmoose/bert-base-cased-qa-evaluator"
         self.SEQ_LENGTH = 512
         self.device = torch.\
             device("cuda" if torch.cuda.is_available() else "cpu")
-        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
+        self.qae_tokenizer = AutoTokenizer.from_pretrained(qae_pretrained)
         self.qae_model = AutoModelForSequenceClassification.from_pretrained(
-            QAE_PRETRAINED
+            qae_pretrained
         )
         self.qae_model.to(self.device)