BERT的中文问答系统14
项目的目录结构,以及每个文件和文件夹的作用说明:
code
project/
│
├── icons/ # 存放图标文件
│ ├── xihe.png
│ └── zero.png
│
├── data/ # 存放数据文件
│ └── train_data.jsonl
│
├── logs/ # 存放日志文件
│
├── models/ # 存放模型文件
│ └── xihua_model.pth
│
├── main.py # 主入口文件
├── requirements.txt # 依赖文件
└── README.md # 项目说明文件
详细说明
项目根目录 (
project/
)
icons/: 存放图标文件,例如 xihe.png 和 zero.png。
data/: 存放训练数据文件,例如 train_data.jsonl。
logs/: 存放日志文件。
models/: 存放训练好的模型文件,例如 xihua_model.pth。
main.py: 项目的主入口文件。
requirements.txt: 项目依赖文件,列出所有需要的Python包及其版本。
README.md: 项目说明文件,介绍项目的目的、如何运行等信息。
文件内容示例
main.py
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))os.makedirs(os.path.dirname(log_file), exist_ok=True)logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:if self.validate_item(item):data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = [item for item in json.load(f) if self.validate_item(item)]except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef validate_item(self, item):required_keys = ['question', 'human_answers', 'chatgpt_answers']if all(key in item for key in required_keys):return Truelogging.warning(f"跳过无效项: 缺少必要键 {required_keys}")return Falsedef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', pa